From 72b19dfb597756c57086d48b4d150e2b4bdeface Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 12 May 2021 09:21:22 +0100 Subject: [PATCH 01/50] added mixin classes for pyro training and sampling --- scvi/model/base/_pyromixin.py | 593 ++++++++++++++++++++++++++++++++++ 1 file changed, 593 insertions(+) create mode 100755 scvi/model/base/_pyromixin.py diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py new file mode 100755 index 0000000000..9ac4d123f2 --- /dev/null +++ b/scvi/model/base/_pyromixin.py @@ -0,0 +1,593 @@ +import logging +from typing import Union + +from typing import Optional + +import numpy as np +import pandas as pd +import pyro +import torch +from pyro import poutine +from pyro.infer import SVI +from pytorch_lightning.callbacks import Callback +from tqdm.auto import tqdm + +from scvi.dataloaders import AnnDataLoader +from scvi.model._utils import parse_use_gpu_arg +from scvi.train import PyroTrainingPlan, Trainer + +logger = logging.getLogger(__name__) + +Number = Union[int, float] + + +class PyroJitGuideWarmup(Callback): + def __init__(self, train_dl) -> None: + super().__init__() + self.dl = train_dl + + def on_train_start(self, trainer, pl_module): + """ + Way to warmup Pyro Guide in an automated way. + Also device agnostic. + """ + + # warmup guide for JIT + pyro_guide = pl_module.module.guide + for tensors in self.dl: + tens = {k: t.to(pl_module.device) for k, t in tensors.items()} + args, kwargs = pl_module.module._get_fn_args_from_batch(tens) + pyro_guide(*args, **kwargs) + break + + +class PyroSviTrainMixin: + """ + This mixin class provides methods for: + + - training models using minibatches and using full data (copies data to GPU only once). + """ + + @property + def _plan_class(self): + return PyroTrainingPlan + + def _train_full_data( + self, + max_epochs: Optional[int] = None, + use_gpu: bool = False, + plan_kwargs: Optional[dict] = None, + lr: float = 0.01, + autoencoding_lr: Optional[float] = None, + clip_norm: float = 200, + continue_training: bool = True, + ): + """ + Private method for training using full data. + + Parameters + ---------- + max_epochs + Number of training epochs / iterations + use_gpu + Bool, use gpu? + plan_kwargs + Training plan arguments such as optim and loss_fn + continue_training + When the model is already trained, should calling .train() continue training? 
(False = restart training) + + Returns + ------- + ELBO history in self.module.history_ + + """ + + args, kwargs = self.module.model._get_fn_args_full_data(self.adata) + gpus, device = parse_use_gpu_arg(use_gpu) + + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + if not continue_training or not self.is_trained_: + # models share param store, make sure it is cleared before training + pyro.clear_param_store() + # initialise guide params (warmup) + self.module.guide(*args, **kwargs) + + svi = SVI( + self.module.model, + self.module.guide, + # select optimiser, optionally choosing different lr for autoencoding guide + pyro.optim.ClippedAdam(self._optim_param(lr, autoencoding_lr, clip_norm)), + loss=plan_kwargs["loss_fn"], + ) + + iter_iterator = tqdm(range(max_epochs)) + hist = [] + for it in iter_iterator: + + loss = svi.step(*args, **kwargs) + iter_iterator.set_description( + "Epoch " + "{:d}".format(it) + ", -ELBO: " + "{:.4e}".format(loss) + ) + hist.append(loss) + + if it % 500 == 0: + torch.cuda.empty_cache() + + if continue_training and self.is_trained_: + # add ELBO listory + hist = self.module.history_ + hist + self.module.history_ = hist + self.module.is_trained_ = True + self.history_ = hist + self.is_trained_ = True + + def _train_minibatch( + self, + max_epochs: Optional[int] = None, + max_steps: Optional[int] = None, + use_gpu: bool = False, + plan_kwargs: Optional[dict] = None, + trainer_kwargs: Optional[dict] = None, + lr: float = 0.01, + optim_kwargs: Optional[dict] = None, + early_stopping: bool = False, + continue_training: bool = True, + ): + """ + Private method for training using minibatches (scVI interface and pytorch lightning). + + Parameters + ---------- + max_epochs + Number of training epochs / iterations + max_steps + Number of training steps + use_gpu + Bool, use gpu? + plan_kwargs + Training plan arguments such as optim and loss_fn + trainer_kwargs + Arguments for scvi.train.Trainer. + optim_kwargs + optimiser creation arguments to such as autoencoding_lr, clip_norm, module_names + early_stopping + Bool, use early stopping? (not tested) + continue_training + When the model is already trained, should calling .train() continue training? 
(False = restart training) + + Returns + ------- + ELBO history in self.module.history_ + + """ + + if not continue_training or not self.is_trained_: + # models share param store, make sure it is cleared before training + pyro.clear_param_store() + + gpus, device = parse_use_gpu_arg(use_gpu) + if max_epochs is None: + n_obs = self.adata.n_obs + max_epochs = np.min([round((20000 / n_obs) * 400), 400]) + + plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() + trainer_kwargs = trainer_kwargs if isinstance(trainer_kwargs, dict) else dict() + optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() + + batch_size = self.module.model.batch_size + # select optimiser, optionally choosing different lr for different parameters + plan_kwargs["optim"] = pyro.optim.ClippedAdam( + self._optim_param(lr, **optim_kwargs) + ) + + # create data loader for training + train_dl = AnnDataLoader(self.adata, shuffle=True, batch_size=batch_size) + plan = PyroTrainingPlan(self.module, **plan_kwargs) + es = "early_stopping" + trainer_kwargs[es] = ( + early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] + ) + trainer = Trainer( + gpus=gpus, + max_epochs=max_epochs, + max_steps=max_steps, + callbacks=[PyroJitGuideWarmup(train_dl)], + **trainer_kwargs + ) + trainer.fit(plan, train_dl) + self.module.to(device) + + try: + if continue_training and self.is_trained_: + # add ELBO listory + index = range( + len(self.module.history_), + len(self.module.history_) + + len(trainer.logger.history["train_loss_epoch"]), + ) + trainer.logger.history["train_loss_epoch"].index = index + self.module.history_ = pd.concat( + [self.module.history_, trainer.logger.history["train_loss_epoch"]] + ) + else: + self.module.history_ = trainer.logger.history["train_loss_epoch"] + self.history_ = self.module.history_ + except AttributeError: + self.history_ = None + + self.module.is_trained_ = True + self.is_trained_ = True + + def _optim_param(self, lr: float = 0.01, clip_norm: float = 200): + # create function which fetches different lr for different parameters + def optim_param(module_name, param_name): + return { + "lr": lr, + # limit the gradient step from becoming too large + "clip_norm": clip_norm, + } + + return optim_param + + def train(self, **kwargs): + """ + Train the model. + + Parameters + ---------- + max_epochs + Number of training epochs / iterations + max_steps + Number of training steps + use_gpu + Bool, use gpu? + lr + Learning rate. + autoencoding_lr + Optional, a separate learning rate for encoder network. + clip_norm + Gradient clipping norm (useful for preventing exploding gradients, + which can lead to impossible values and NaN loss). + trainer_kwargs + Training plan arguments for scvi.train.PyroTrainingPlan (Excluding optim and loss_fn) + early_stopping + Bool, use early stopping? 
(not tested) + + Returns + ------- + ELBO history in self.module.history_ + + """ + + plan_kwargs = {"loss_fn": pyro.infer.Trace_ELBO()} + + batch_size = self.module.model.batch_size + + if batch_size is None: + # train using full data (faster for small datasets) + self._train_full_data(plan_kwargs=plan_kwargs, **kwargs) + else: + # standard training using minibatches + self._train_minibatch(plan_kwargs=plan_kwargs, **kwargs) + + +class PyroSampleMixin: + """ + This mixin class provides methods for: + + - generating samples from posterior distribution using minibatches and full data + """ + + def _get_one_posterior_sample( + self, + args, + kwargs, + return_sites: Optional[list] = None, + sample_observed: bool = False, + ): + """Get one sample from posterior distribution. + + Parameters + ---------- + args + arguments to model and guide + kwargs + arguments to model and guide + + Returns + ------- + Dictionary with a sample for each variable + + """ + + guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) + model_trace = poutine.trace( + poutine.replay(self.module.model, guide_trace) + ).get_trace(*args, **kwargs) + + sample = { + name: site["value"].detach().cpu().numpy() + for name, site in model_trace.nodes.items() + if ( + (site["type"] == "sample") # sample statement + and ( + (return_sites is None) or (name in return_sites) + ) # selected in return_sites list + and ( + ( + (not site.get("is_observed", True)) or sample_observed + ) # don't save observed + or (site.get("infer", False).get("_deterministic", False)) + ) # unless it is deterministic + and not isinstance( + site.get("fn", None), poutine.subsample_messenger._Subsample + ) # don't save plates + ) + } + + return sample + + def _get_posterior_samples( + self, + args, + kwargs, + num_samples: int = 1000, + return_sites: Optional[list] = None, + return_observed: bool = False, + show_progress: bool = True, + ): + """ + Get many samples from posterior distribution. 
+ + Parameters + ---------- + args + arguments to model and guide + kwargs + arguments to model and guide + show_progress + show progress bar + + Returns + ------- + Dictionary with array of samples for each variable + dictionary {variable_name: [array with samples in 0 dimension]} + + """ + + samples = self._get_one_posterior_sample( + args, kwargs, return_sites=return_sites, sample_observed=return_observed + ) + samples = {k: [v] for k, v in samples.items()} + + for _ in tqdm( + range(1, num_samples), + disable=not show_progress, + desc="Sampling global variables, sample: ", + ): + # generate new sample + samples_ = self._get_one_posterior_sample( + args, kwargs, return_sites=return_sites, sample_observed=return_observed + ) + + # add new sample + samples = {k: samples[k] + [samples_[k]] for k in samples.keys()} + + return {k: np.array(v) for k, v in samples.items()} + + def _posterior_samples_full_data(self, use_gpu: bool = True, **sample_kwargs): + """ + Generate samples from posterior distribution using all data + + Parameters + ---------- + sample_kwargs + arguments to _get_posterior_samples + + Returns + ------- + dictionary {variable_name: [array with samples in 0 dimension]} + + """ + + self.module.eval() + gpus, device = parse_use_gpu_arg(use_gpu) + + args, kwargs = self.module.model._get_fn_args_full_data(self.adata) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) + + return samples + + def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): + """ + Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables + and global variables, which is necessary when performing minibatch inference. + + Note for developers: requires model class method which lists observation/minibatch plate + variables (self.module.model.list_obs_plate_vars()). + + Parameters + ---------- + use_gpu + Bool, use gpu? 
+ + Returns + ------- + dictionary {variable_name: [array with samples in 0 dimension]} + + """ + + gpus, device = parse_use_gpu_arg(use_gpu) + + self.module.eval() + + train_dl = AnnDataLoader( + self.adata, shuffle=False, batch_size=self.module.model.batch_size + ) + # sample local parameters + i = 0 + with tqdm(train_dl, desc="Sampling local variables, batch: ") as tqdm_dl: + for tensor_dict in tqdm_dl: + if i == 0: + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + # check whether any variable requested in return_sites are in obs_plate + sample_kwargs_obs_plate = sample_kwargs.copy() + if ("return_sites" in sample_kwargs.keys()) and ( + sample_kwargs["return_sites"] is not None + ): + return_sites = np.array(sample_kwargs["return_sites"]) + return_sites = return_sites[ + np.isin( + return_sites, + list( + self.module.model.list_obs_plate_vars()[ + "sites" + ].keys() + ), + ) + ] + if len(return_sites) == 0: + sample_kwargs_obs_plate["return_sites"] = [return_sites] + else: + sample_kwargs_obs_plate["return_sites"] = list(return_sites) + else: + sample_kwargs_obs_plate["return_sites"] = list( + self.module.model.list_obs_plate_vars()["sites"].keys() + ) + sample_kwargs_obs_plate["show_progress"] = False + samples = self._get_posterior_samples( + args, kwargs, **sample_kwargs_obs_plate + ) + + # find plate dimension + trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) + obs_plate = { + name: site["cond_indep_stack"][0].dim + for name, site in trace.nodes.items() + if site["type"] == "sample" + if any( + f.name == self.module.model.list_obs_plate_vars()["name"] + for f in site["cond_indep_stack"] + ) + } + + else: + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + samples_ = self._get_posterior_samples( + args, kwargs, **sample_kwargs_obs_plate + ) + samples = { + k: np.array( + [ + np.concatenate( + [samples[k][i], samples_[k][i]], + axis=list(obs_plate.values())[0], + ) + for i in range(len(samples[k])) + ] + ) + for k in samples.keys() + } + i += 1 + + # sample global parameters + i = 0 + for tensor_dict in train_dl: + if i == 0: + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + global_samples = self._get_posterior_samples( + args, kwargs, **sample_kwargs + ) + global_samples = { + k: global_samples[k] + for k in global_samples.keys() + if k not in self.module.model.list_obs_plate_vars()["sites"] + } + i += 1 + + for k in global_samples.keys(): + samples[k] = global_samples[k] + + self.module.to(device) + + return samples + + def sample_posterior( + self, + num_samples: int = 1000, + return_sites: Optional[list] = None, + use_gpu: bool = False, + sample_kwargs=None, + return_samples: bool = False, + ): + """ + Generate samples from posterior distribution for each parameter + + Parameters + ---------- + num_samples + number of posterior samples to generate. + return_sites + get samples for pyro model variable, default is all variables, otherwise list variable names). + use_gpu + Use gpu? + sample_kwargs + dictionary with arguments to _get_posterior_samples (see below): + return_observed + return observed sites/variables? 
+ return_samples + return samples in addition to sample mean, 5%/95% quantile and SD? + + Returns + ------- + Posterior distribution samples, a dictionary for each of (mean, 5% quantile, SD, optionally all samples), + containing dictionaries for each variable with numpy arrays. + Dictionary of all samples contains samples for each variable as numpy arrays of shape ``(n_samples, ...)`` + + """ + + sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict() + sample_kwargs["num_samples"] = num_samples + sample_kwargs["return_sites"] = return_sites + + if self.module.model.batch_size is None: + # sample using full data + samples = self._posterior_samples_full_data( + use_gpu=use_gpu, **sample_kwargs + ) + else: + # sample using minibatches + samples = self._posterior_samples_minibatch( + use_gpu=use_gpu, **sample_kwargs + ) + + param_names = list(samples.keys()) + results = dict() + if return_samples: + results["posterior_samples"] = samples + + results["post_sample_means"] = {v: samples[v].mean(axis=0) for v in param_names} + results["post_sample_q05"] = self.posterior_quantile(q=0.05, use_gpu=use_gpu) + results["post_sample_q95"] = self.posterior_quantile(q=0.95, use_gpu=use_gpu) + results["post_sample_sds"] = {v: samples[v].std(axis=0) for v in param_names} + + return results From 78bcb473ee435e89027c2f2afd1e85fc45dc82be Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 12 May 2021 12:47:57 +0100 Subject: [PATCH 02/50] added tests --- scvi/model/base/__init__.py | 4 + scvi/model/base/_pyromixin.py | 28 +++--- tests/models/test_pyro.py | 175 ++++++++++++++++++++++++++-------- 3 files changed, 151 insertions(+), 56 deletions(-) diff --git a/scvi/model/base/__init__.py b/scvi/model/base/__init__.py index ce7fb446b1..83e5b6ea14 100644 --- a/scvi/model/base/__init__.py +++ b/scvi/model/base/__init__.py @@ -1,5 +1,6 @@ from ._archesmixin import ArchesMixin from ._base_model import BaseModelClass +from ._pyromixin import PyroJitGuideWarmup, PyroSampleMixin, PyroSviTrainMixin from ._rnamixin import RNASeqMixin from ._training_mixin import UnsupervisedTrainingMixin from ._vaemixin import VAEMixin @@ -10,4 +11,7 @@ "RNASeqMixin", "VAEMixin", "UnsupervisedTrainingMixin", + "PyroSviTrainMixin", + "PyroSampleMixin", + "PyroJitGuideWarmup", ] diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 9ac4d123f2..8599a89b2d 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -1,7 +1,5 @@ import logging -from typing import Union - -from typing import Optional +from typing import Optional, Union import numpy as np import pandas as pd @@ -58,8 +56,7 @@ def _train_full_data( use_gpu: bool = False, plan_kwargs: Optional[dict] = None, lr: float = 0.01, - autoencoding_lr: Optional[float] = None, - clip_norm: float = 200, + optim_kwargs: Optional[dict] = None, continue_training: bool = True, ): """ @@ -84,6 +81,7 @@ def _train_full_data( args, kwargs = self.module.model._get_fn_args_full_data(self.adata) gpus, device = parse_use_gpu_arg(use_gpu) + optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() args = [a.to(device) for a in args] kwargs = {k: v.to(device) for k, v in kwargs.items()} @@ -99,7 +97,7 @@ def _train_full_data( self.module.model, self.module.guide, # select optimiser, optionally choosing different lr for autoencoding guide - pyro.optim.ClippedAdam(self._optim_param(lr, autoencoding_lr, clip_norm)), + pyro.optim.ClippedAdam(self._optim_param(lr, **optim_kwargs)), loss=plan_kwargs["loss_fn"], ) 
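[Note for readers unfamiliar with Pyro's optimiser wrappers — an illustrative sketch, not part of this changeset: `pyro.optim.ClippedAdam` accepts either a fixed dict of Adam arguments or a callable returning per-parameter options (the two-argument `(module_name, param_name)` form is what `_optim_param` builds in the hunk above), and the resulting optimiser is handed to `SVI` together with the guide and an ELBO loss. The toy model, guide and data below are invented purely for the example.]

    import torch
    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        loc_q = pyro.param("loc_q", torch.tensor(0.0))
        pyro.sample("loc", dist.Delta(loc_q))

    # Pyro calls this once per parameter (with module and parameter name,
    # as in _optim_param above) and uses the returned dict to configure Adam.
    def optim_param(module_name, param_name):
        return {"lr": 0.01, "clip_norm": 200.0}  # clip_norm caps the gradient norm per step

    pyro.clear_param_store()
    svi = SVI(model, guide, pyro.optim.ClippedAdam(optim_param), loss=Trace_ELBO())
    data = torch.randn(100) + 3.0
    for step in range(10):
        loss = svi.step(data)  # one SVI update; returns an estimate of -ELBO

[The mixin above uses the same pattern with `self.module.model` / `self.module.guide` and the full-data args.]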
@@ -177,7 +175,7 @@ def _train_minibatch( trainer_kwargs = trainer_kwargs if isinstance(trainer_kwargs, dict) else dict() optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() - batch_size = self.module.model.batch_size + batch_size = self.batch_size # select optimiser, optionally choosing different lr for different parameters plan_kwargs["optim"] = pyro.optim.ClippedAdam( self._optim_param(lr, **optim_kwargs) @@ -264,7 +262,7 @@ def train(self, **kwargs): plan_kwargs = {"loss_fn": pyro.infer.Trace_ELBO()} - batch_size = self.module.model.batch_size + batch_size = self.batch_size if batch_size is None: # train using full data (faster for small datasets) @@ -428,9 +426,7 @@ def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): self.module.eval() - train_dl = AnnDataLoader( - self.adata, shuffle=False, batch_size=self.module.model.batch_size - ) + train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=self.batch_size) # sample local parameters i = 0 with tqdm(train_dl, desc="Sampling local variables, batch: ") as tqdm_dl: @@ -569,7 +565,7 @@ def sample_posterior( sample_kwargs["num_samples"] = num_samples sample_kwargs["return_sites"] = return_sites - if self.module.model.batch_size is None: + if self.batch_size is None: # sample using full data samples = self._posterior_samples_full_data( use_gpu=use_gpu, **sample_kwargs @@ -586,8 +582,12 @@ def sample_posterior( results["posterior_samples"] = samples results["post_sample_means"] = {v: samples[v].mean(axis=0) for v in param_names} - results["post_sample_q05"] = self.posterior_quantile(q=0.05, use_gpu=use_gpu) - results["post_sample_q95"] = self.posterior_quantile(q=0.95, use_gpu=use_gpu) + results["post_sample_q05"] = { + v: np.quantile(samples[v], 0.05, axis=0) for v in param_names + } + results["post_sample_q95"] = { + v: np.quantile(samples[v], 0.95, axis=0) for v in param_names + } results["post_sample_sds"] = {v: samples[v].std(axis=0) for v in param_names} return results diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index b280dce9c6..08ba52773c 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -5,46 +5,31 @@ import pyro.distributions as dist import torch import torch.nn as nn -from pyro.infer.autoguide import AutoDiagonalNormal +from anndata import AnnData +from pyro.infer.autoguide import AutoNormal, init_to_mean from pyro.nn import PyroModule, PyroSample -from pytorch_lightning.callbacks import Callback +from scipy.sparse import issparse from scvi import _CONSTANTS -from scvi.data import synthetic_iid +from scvi.data import register_tensor_from_anndata, synthetic_iid +from scvi.data._anndata import get_from_registry from scvi.dataloaders import AnnDataLoader +from scvi.model.base import ( + BaseModelClass, + PyroJitGuideWarmup, + PyroSampleMixin, + PyroSviTrainMixin, +) from scvi.module.base import PyroBaseModuleClass from scvi.train import PyroTrainingPlan, Trainer -class PyroJitGuideWarmup(Callback): - def __init__(self, train_dl) -> None: - super().__init__() - self.dl = train_dl - - def on_train_start(self, trainer, pl_module): - """ - Way to warmup Pyro Guide in an automated way. - - Also device agnostic. 
- """ - - # warmup guide for JIT - pyro_model = pl_module.module.model - dev = pyro_model.linear.weight.device - pyro_guide = pl_module.module.guide - for tensors in self.dl: - tens = {k: t.to(dev) for k, t in tensors.items()} - args, kwargs = pl_module.module._get_fn_args_from_batch(tens) - pyro_guide(*args, **kwargs) - break - - class BayesianRegressionPyroModel(PyroModule): - def __init__(self, in_features, out_features): + def __init__(self, in_features, out_features, n_obs): super().__init__() self.in_features = in_features self.out_features = out_features - self.n_obs = None + self.n_obs = n_obs self.register_buffer("zero", torch.tensor(0.0)) self.register_buffer("one", torch.tensor(1.0)) @@ -62,20 +47,58 @@ def __init__(self, in_features, out_features): .to_event(1) ) - def forward(self, x, y): - sigma = pyro.sample("sigma", dist.Uniform(self.zero, self.ten)) + def create_plates(self, x, y, ind_x): + return pyro.plate("data", size=self.n_obs, dim=-2, subsample=ind_x) + + def list_obs_plate_vars(self): + """Create a dictionary with the name of observation/minibatch plate, + indexes of model args to provide to encoder, + variable names that belong to the observation plate + and the number of dimensions in non-plate axis of each variable""" + + return { + "name": "obs_plate", + "in": [0], # index for expression data + "sites": {}, + } + + @staticmethod + def _get_fn_args_from_batch(tensor_dict): + x = tensor_dict[_CONSTANTS.X_KEY] + y = tensor_dict[_CONSTANTS.LABELS_KEY] + ind_x = tensor_dict["ind_x"].long().squeeze() + return (x, y, ind_x), {} + + @staticmethod + def _get_fn_args_full_data(adata): + x = get_from_registry(adata, _CONSTANTS.X_KEY) + if issparse(x): + x = np.asarray(x.toarray()) + x = torch.tensor(x.astype("float32")) + ind_x = torch.tensor(get_from_registry(adata, "ind_x")) + y = torch.tensor(get_from_registry(adata, _CONSTANTS.LABELS_KEY)) + return (x, y, ind_x), {} + + def forward(self, x, y, ind_x): + + obs_plate = self.create_plates(x, y, ind_x) + + sigma = pyro.sample("sigma", dist.Exponential(self.one)) mean = self.linear(x).squeeze(-1) - with pyro.plate("data", size=self.n_obs, subsample_size=x.shape[0]): + with obs_plate: pyro.sample("obs", dist.Normal(mean, sigma), obs=y) return mean class BayesianRegressionModule(PyroBaseModuleClass): - def __init__(self, in_features, out_features): + def __init__(self, **kwargs): super().__init__() - self._model = BayesianRegressionPyroModel(in_features, out_features) - self._guide = AutoDiagonalNormal(self.model) + self._model = BayesianRegressionPyroModel(**kwargs) + self._guide = AutoNormal( + self.model, init_loc_fn=init_to_mean, create_plates=self.model.create_plates + ) + self._get_fn_args_from_batch = self._model._get_fn_args_from_batch @property def model(self): @@ -85,20 +108,48 @@ def model(self): def guide(self): return self._guide - @staticmethod - def _get_fn_args_from_batch(tensor_dict): - x = tensor_dict[_CONSTANTS.X_KEY] - y = tensor_dict[_CONSTANTS.LABELS_KEY] - return (x, y.squeeze(1)), {} +class BayesianRegressionModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): + def __init__( + self, + adata: AnnData, + batch_size=None, + ): + # add index for each cell (provided to pyro plate for correct minibatching) + adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") + register_tensor_from_anndata( + adata, + registry_key="ind_x", + adata_attr_name="obs", + adata_key_name="_indices", + ) + + super().__init__(adata) + + self.batch_size = batch_size + self.module = BayesianRegressionModule( + 
in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs + ) + self._model_summary_string = "BayesianRegressionModel" + self.init_params_ = self._get_init_params(locals()) def test_pyro_bayesian_regression(save_path): use_gpu = int(torch.cuda.is_available()) adata = synthetic_iid() + # add index for each cell (provided to pyro plate for correct minibatching) + adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") + register_tensor_from_anndata( + adata, + registry_key="ind_x", + adata_attr_name="obs", + adata_key_name="_indices", + ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() - model = BayesianRegressionModule(adata.shape[1], 1) + model = BayesianRegressionModule( + in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs + ) plan = PyroTrainingPlan(model, n_obs=len(train_dl.indices)) trainer = Trainer( gpus=use_gpu, @@ -156,9 +207,19 @@ def test_pyro_bayesian_regression(save_path): def test_pyro_bayesian_regression_jit(): use_gpu = int(torch.cuda.is_available()) adata = synthetic_iid() + # add index for each cell (provided to pyro plate for correct minibatching) + adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") + register_tensor_from_anndata( + adata, + registry_key="ind_x", + adata_attr_name="obs", + adata_key_name="_indices", + ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() - model = BayesianRegressionModule(adata.shape[1], 1) + model = BayesianRegressionModule( + in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs + ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) plan = PyroTrainingPlan( model, loss_fn=pyro.infer.JitTrace_ELBO(), n_obs=len(train_dl.indices) @@ -169,7 +230,7 @@ def test_pyro_bayesian_regression_jit(): trainer.fit(plan, train_dl) # 100 features, 1 for sigma, 1 for bias - assert list(model.guide.parameters())[0].shape[0] == 102 + # assert list(model.guide.parameters())[0].shape[0] == 102 if use_gpu == 1: model.cuda() @@ -184,3 +245,33 @@ def test_pyro_bayesian_regression_jit(): for k, v in predictive(*args, **kwargs).items() if k != "obs" } + + +def test_pyro_bayesian_train_sample_mixin(): + use_gpu = torch.cuda.is_available() + adata = synthetic_iid() + mod = BayesianRegressionModel(adata, batch_size=128) + mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + + # 100 features, 1 for sigma, 1 for bias + # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + + # test posterior sampling + samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + + assert len(samples["posterior_samples"]["sigma"]) == 10 + + +def test_pyro_bayesian_train_sample_mixin_full_data(): + use_gpu = torch.cuda.is_available() + adata = synthetic_iid() + mod = BayesianRegressionModel(adata, batch_size=None) + mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + + # 100 features, 1 for sigma, 1 for bias + # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + + # test posterior sampling + samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + + assert len(samples["posterior_samples"]["sigma"]) == 10 From 865042721b4e6b2052224937d12ed26536e7b530 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 12 May 2021 22:20:04 +0100 Subject: [PATCH 03/50] added test with local variables --- tests/models/test_pyro.py | 47 ++++++++++++++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py 
index 08ba52773c..0212754401 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -25,12 +25,14 @@ class BayesianRegressionPyroModel(PyroModule): - def __init__(self, in_features, out_features, n_obs): + def __init__(self, in_features, out_features, n_obs, per_cell_weight=False): super().__init__() self.in_features = in_features self.out_features = out_features self.n_obs = n_obs + self.per_cell_weight = per_cell_weight + self.register_buffer("zero", torch.tensor(0.0)) self.register_buffer("one", torch.tensor(1.0)) self.register_buffer("ten", torch.tensor(10.0)) @@ -48,7 +50,7 @@ def __init__(self, in_features, out_features, n_obs): ) def create_plates(self, x, y, ind_x): - return pyro.plate("data", size=self.n_obs, dim=-2, subsample=ind_x) + return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=ind_x) def list_obs_plate_vars(self): """Create a dictionary with the name of observation/minibatch plate, @@ -59,7 +61,7 @@ def list_obs_plate_vars(self): return { "name": "obs_plate", "in": [0], # index for expression data - "sites": {}, + "sites": {"per_cell_weights": 1}, } @staticmethod @@ -84,7 +86,16 @@ def forward(self, x, y, ind_x): obs_plate = self.create_plates(x, y, ind_x) sigma = pyro.sample("sigma", dist.Exponential(self.one)) + mean = self.linear(x).squeeze(-1) + + if self.per_cell_weight: + with obs_plate: + per_cell_weights = pyro.sample( + "per_cell_weights", dist.Normal(self.zero, self.one) + ) + mean = mean + per_cell_weights.squeeze(-1) + with obs_plate: pyro.sample("obs", dist.Normal(mean, sigma), obs=y) return mean @@ -114,6 +125,7 @@ def __init__( self, adata: AnnData, batch_size=None, + per_cell_weight=False, ): # add index for each cell (provided to pyro plate for correct minibatching) adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") @@ -128,7 +140,10 @@ def __init__( self.batch_size = batch_size self.module = BayesianRegressionModule( - in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs + in_features=adata.shape[1], + out_features=1, + n_obs=adata.n_obs, + per_cell_weight=per_cell_weight, ) self._model_summary_string = "BayesianRegressionModel" self.init_params_ = self._get_init_params(locals()) @@ -180,7 +195,9 @@ def test_pyro_bayesian_regression(save_path): torch.save(model.state_dict(), model_save_path) pyro.clear_param_store() - new_model = BayesianRegressionModule(adata.shape[1], 1) + new_model = BayesianRegressionModule( + in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs + ) # run model one step to get autoguide params try: new_model.load_state_dict(torch.load(model_save_path)) @@ -275,3 +292,23 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) assert len(samples["posterior_samples"]["sigma"]) == 10 + + +def test_pyro_bayesian_train_sample_mixin_with_local(): + use_gpu = torch.cuda.is_available() + adata = synthetic_iid() + mod = BayesianRegressionModel(adata, batch_size=128, per_cell_weight=True) + mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + + # 100 features, 1 for sigma, 1 for bias + # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + + # test posterior sampling + samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + + assert len(samples["posterior_samples"]["sigma"]) == 10 + assert samples["posterior_samples"]["per_cell_weights"].shape == ( + 10, + adata.n_obs, + 1, + ) From a9ad0d22837eeba1170bb22bf7c24f993406d693 Mon Sep 17 00:00:00 2001 From: 
Vitalii Kleshchevnikov Date: Fri, 14 May 2021 09:33:47 +0100 Subject: [PATCH 04/50] refactored posterior sampling, replaced tqdm, other minor changes --- scvi/model/base/_pyromixin.py | 195 +++++++++++++++++----------------- tests/models/test_pyro.py | 19 ++-- 2 files changed, 107 insertions(+), 107 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 8599a89b2d..d37eb9f96c 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -8,11 +8,11 @@ from pyro import poutine from pyro.infer import SVI from pytorch_lightning.callbacks import Callback -from tqdm.auto import tqdm from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_use_gpu_arg from scvi.train import PyroTrainingPlan, Trainer +from scvi.utils import track logger = logging.getLogger(__name__) @@ -46,10 +46,6 @@ class PyroSviTrainMixin: - training models using minibatches and using full data (copies data to GPU only once). """ - @property - def _plan_class(self): - return PyroTrainingPlan - def _train_full_data( self, max_epochs: Optional[int] = None, @@ -87,6 +83,11 @@ def _train_full_data( kwargs = {k: v.to(device) for k, v in kwargs.items()} self.to_device(device) + if hasattr(self.module.model, "n_obs"): + setattr(self.module.model, "n_obs", self.adata.n_obs) + if hasattr(self.module.guide, "n_obs"): + setattr(self.module.guide, "n_obs", self.adata.n_obs) + if not continue_training or not self.is_trained_: # models share param store, make sure it is cleared before training pyro.clear_param_store() @@ -101,7 +102,10 @@ def _train_full_data( loss=plan_kwargs["loss_fn"], ) - iter_iterator = tqdm(range(max_epochs)) + iter_iterator = track( + range(max_epochs), + style="tqdm", + ) hist = [] for it in iter_iterator: @@ -172,6 +176,7 @@ def _train_minibatch( max_epochs = np.min([round((20000 / n_obs) * 400), 400]) plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() + plan_kwargs["n_obs"] = self.adata.n_obs trainer_kwargs = trainer_kwargs if isinstance(trainer_kwargs, dict) else dict() optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() @@ -192,7 +197,7 @@ def _train_minibatch( gpus=gpus, max_epochs=max_epochs, max_steps=max_steps, - callbacks=[PyroJitGuideWarmup(train_dl)], + # callbacks=[PyroJitGuideWarmup(train_dl)], **trainer_kwargs ) trainer.fit(plan, train_dl) @@ -279,6 +284,7 @@ class PyroSampleMixin: - generating samples from posterior distribution using minibatches and full data """ + @torch.no_grad() def _get_one_posterior_sample( self, args, @@ -307,7 +313,7 @@ def _get_one_posterior_sample( ).get_trace(*args, **kwargs) sample = { - name: site["value"].detach().cpu().numpy() + name: site["value"].cpu().numpy() for name, site in model_trace.nodes.items() if ( (site["type"] == "sample") # sample statement @@ -317,7 +323,7 @@ def _get_one_posterior_sample( and ( ( (not site.get("is_observed", True)) or sample_observed - ) # don't save observed + ) # don't save observed unless requested or (site.get("infer", False).get("_deterministic", False)) ) # unless it is deterministic and not isinstance( @@ -361,11 +367,13 @@ def _get_posterior_samples( ) samples = {k: [v] for k, v in samples.items()} - for _ in tqdm( + for _ in track( range(1, num_samples), + style="tqdm", + description="Sampling global variables, sample: ", disable=not show_progress, - desc="Sampling global variables, sample: ", ): + # generate new sample samples_ = self._get_one_posterior_sample( args, kwargs, return_sites=return_sites, 
sample_observed=return_observed @@ -403,6 +411,36 @@ def _posterior_samples_full_data(self, use_gpu: bool = True, **sample_kwargs): return samples + def _check_obs_plate_return_sites(self, sample_kwargs): + # check whether any variable requested in return_sites are in obs_plate + obs_plate_sites = list(self.module.model.list_obs_plate_vars()["sites"].keys()) + if ("return_sites" in sample_kwargs.keys()) and ( + sample_kwargs["return_sites"] is not None + ): + return_sites = np.array(sample_kwargs["return_sites"]) + return_sites = return_sites[np.isin(return_sites, obs_plate_sites)] + if len(return_sites) == 0: + return [return_sites] + else: + return list(return_sites) + else: + return obs_plate_sites + + def _find_plate_dimension(self, args, kwargs): + + # find plate dimension + trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) + obs_plate = { + name: site["cond_indep_stack"][0].dim + for name, site in trace.nodes.items() + if site["type"] == "sample" + if any( + f.name == self.module.model.list_obs_plate_vars()["name"] + for f in site["cond_indep_stack"] + ) + } + return obs_plate + def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): """ Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables @@ -429,96 +467,63 @@ def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=self.batch_size) # sample local parameters i = 0 - with tqdm(train_dl, desc="Sampling local variables, batch: ") as tqdm_dl: - for tensor_dict in tqdm_dl: - if i == 0: - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - # check whether any variable requested in return_sites are in obs_plate - sample_kwargs_obs_plate = sample_kwargs.copy() - if ("return_sites" in sample_kwargs.keys()) and ( - sample_kwargs["return_sites"] is not None - ): - return_sites = np.array(sample_kwargs["return_sites"]) - return_sites = return_sites[ - np.isin( - return_sites, - list( - self.module.model.list_obs_plate_vars()[ - "sites" - ].keys() - ), + for tensor_dict in track( + train_dl, + style="tqdm", + description="Sampling local variables, batch: ", + ): + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + if i == 0: + sample_kwargs_obs_plate = sample_kwargs.copy() + sample_kwargs_obs_plate[ + "return_sites" + ] = self._check_obs_plate_return_sites(sample_kwargs) + sample_kwargs_obs_plate["show_progress"] = False + obs_plate = self._find_plate_dimension(args, kwargs) + obs_plate_dim = list(obs_plate.values())[0] + samples = self._get_posterior_samples( + args, kwargs, **sample_kwargs_obs_plate + ) + else: + samples_ = self._get_posterior_samples( + args, kwargs, **sample_kwargs_obs_plate + ) + + samples = { + k: np.array( + [ + np.concatenate( + [samples[k][j], samples_[k][j]], + axis=obs_plate_dim, ) + for j in range( + len(samples[k]) + ) # for each sample (in 0 dimension ] - if len(return_sites) == 0: - sample_kwargs_obs_plate["return_sites"] = [return_sites] - else: - sample_kwargs_obs_plate["return_sites"] = list(return_sites) - else: - sample_kwargs_obs_plate["return_sites"] = list( - self.module.model.list_obs_plate_vars()["sites"].keys() - ) - sample_kwargs_obs_plate["show_progress"] = 
False - samples = self._get_posterior_samples( - args, kwargs, **sample_kwargs_obs_plate - ) - - # find plate dimension - trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) - obs_plate = { - name: site["cond_indep_stack"][0].dim - for name, site in trace.nodes.items() - if site["type"] == "sample" - if any( - f.name == self.module.model.list_obs_plate_vars()["name"] - for f in site["cond_indep_stack"] - ) - } - - else: - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - samples_ = self._get_posterior_samples( - args, kwargs, **sample_kwargs_obs_plate ) - samples = { - k: np.array( - [ - np.concatenate( - [samples[k][i], samples_[k][i]], - axis=list(obs_plate.values())[0], - ) - for i in range(len(samples[k])) - ] - ) - for k in samples.keys() - } - i += 1 + for k in samples.keys() # for each variable + } + i += 1 # sample global parameters - i = 0 for tensor_dict in train_dl: - if i == 0: - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - global_samples = self._get_posterior_samples( - args, kwargs, **sample_kwargs - ) - global_samples = { - k: global_samples[k] - for k in global_samples.keys() - if k not in self.module.model.list_obs_plate_vars()["sites"] - } - i += 1 + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) + global_samples = { + k: v + for k, v in global_samples.items() + if k + not in list(self.module.model.list_obs_plate_vars()["sites"].keys()) + } + break for k in global_samples.keys(): samples[k] = global_samples[k] diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 0212754401..9e30315589 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -25,11 +25,11 @@ class BayesianRegressionPyroModel(PyroModule): - def __init__(self, in_features, out_features, n_obs, per_cell_weight=False): + def __init__(self, in_features, out_features, per_cell_weight=False): super().__init__() self.in_features = in_features self.out_features = out_features - self.n_obs = n_obs + self.n_obs = None self.per_cell_weight = per_cell_weight @@ -50,6 +50,8 @@ def __init__(self, in_features, out_features, n_obs, per_cell_weight=False): ) def create_plates(self, x, y, ind_x): + """Function for creating plates is needed when using AutoGuides + and should have the same call signature as model""" return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=ind_x) def list_obs_plate_vars(self): @@ -142,7 +144,6 @@ def __init__( self.module = BayesianRegressionModule( in_features=adata.shape[1], out_features=1, - n_obs=adata.n_obs, per_cell_weight=per_cell_weight, ) self._model_summary_string = "BayesianRegressionModel" @@ -162,9 +163,7 @@ def test_pyro_bayesian_regression(save_path): ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() - model = BayesianRegressionModule( - in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs - ) + model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) plan = PyroTrainingPlan(model, n_obs=len(train_dl.indices)) trainer = Trainer( gpus=use_gpu, @@ -195,9 +194,7 @@ def 
test_pyro_bayesian_regression(save_path): torch.save(model.state_dict(), model_save_path) pyro.clear_param_store() - new_model = BayesianRegressionModule( - in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs - ) + new_model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) # run model one step to get autoguide params try: new_model.load_state_dict(torch.load(model_save_path)) @@ -234,9 +231,7 @@ def test_pyro_bayesian_regression_jit(): ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() - model = BayesianRegressionModule( - in_features=adata.shape[1], out_features=1, n_obs=adata.n_obs - ) + model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) plan = PyroTrainingPlan( model, loss_fn=pyro.infer.JitTrace_ELBO(), n_obs=len(train_dl.indices) From 9ff126eb9a7c34b5932a28474514ca4c5f30274e Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 14 May 2021 11:44:48 +0100 Subject: [PATCH 05/50] using TrainRunner --- scvi/model/base/_pyromixin.py | 77 +++++++++++++++++++++++++++++++++-- scvi/train/_trainingplans.py | 15 ++++++- tests/models/test_pyro.py | 26 ++++++++---- 3 files changed, 105 insertions(+), 13 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index d37eb9f96c..a3767691be 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -9,9 +9,9 @@ from pyro.infer import SVI from pytorch_lightning.callbacks import Callback -from scvi.dataloaders import AnnDataLoader +from scvi.dataloaders import AnnDataLoader, DataSplitter from scvi.model._utils import parse_use_gpu_arg -from scvi.train import PyroTrainingPlan, Trainer +from scvi.train import PyroTrainingPlan, Trainer, TrainRunner from scvi.utils import track logger = logging.getLogger(__name__) @@ -46,6 +46,75 @@ class PyroSviTrainMixin: - training models using minibatches and using full data (copies data to GPU only once). """ + def train( + self, + max_epochs: Optional[int] = None, + use_gpu: Optional[Union[str, int, bool]] = None, + train_size: float = 0.9, + validation_size: Optional[float] = None, + batch_size: int = 128, + early_stopping: bool = False, + plan_kwargs: Optional[dict] = None, + **trainer_kwargs, + ): + """ + Train the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset. If `None`, defaults to + `np.min([round((20000 / n_cells) * 400), 400])` + use_gpu + Use default GPU if available (if None or True), or index of GPU to use (if int), + or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). + train_size + Size of training set in the range [0.0, 1.0]. + validation_size + Size of the test set. If `None`, defaults to 1 - `train_size`. If + `train_size + validation_size < 1`, the remaining cells belong to a test set. + batch_size + Minibatch size to use during training. + early_stopping + Perform early stopping. Additional arguments can be passed in `**kwargs`. + See :class:`~scvi.train.Trainer` for further options. + plan_kwargs + Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to + `train()` will overwrite values present in `plan_kwargs`, when appropriate. + **trainer_kwargs + Other keyword args for :class:`~scvi.train.Trainer`. 
+ """ + if max_epochs is None: + n_obs = self.adata.n_obs + max_epochs = np.min([round((20000 / n_obs) * 1000), 1000]) + + plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() + + data_splitter = DataSplitter( + self.adata, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + use_gpu=use_gpu, + ) + training_plan = PyroTrainingPlan( + pyro_module=self.module, n_obs=len(data_splitter.train_idx), **plan_kwargs + ) + + es = "early_stopping" + trainer_kwargs[es] = ( + early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] + ) + runner = TrainRunner( + self, + training_plan=training_plan, + data_splitter=data_splitter, + max_epochs=max_epochs, + use_gpu=use_gpu, + **trainer_kwargs, + ) + return runner() + def _train_full_data( self, max_epochs: Optional[int] = None, @@ -198,7 +267,7 @@ def _train_minibatch( max_epochs=max_epochs, max_steps=max_steps, # callbacks=[PyroJitGuideWarmup(train_dl)], - **trainer_kwargs + **trainer_kwargs, ) trainer.fit(plan, train_dl) self.module.to(device) @@ -235,7 +304,7 @@ def optim_param(module_name, param_name): return optim_param - def train(self, **kwargs): + def train_v2(self, **kwargs): """ Train the model. diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index 3a0b614012..ecdd00eec3 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -597,6 +597,7 @@ def __init__( loss_fn: Optional[pyro.infer.ELBO] = None, optim: Optional[pyro.optim.PyroOptim] = None, n_obs: Optional[int] = None, + optim_kwargs: Optional[dict] = None, ): super().__init__() self.module = pyro_module @@ -609,8 +610,20 @@ def __init__( if hasattr(self.module.guide, "n_obs"): setattr(self.module.guide, "n_obs", n_obs) + optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() + optim_kwargs["lr"] = ( + optim_kwargs["lr"] if "lr" in list(optim_kwargs.keys()) else 1e-3 + ) + optim_kwargs["clip_norm"] = ( + optim_kwargs["clip_norm"] + if "clip_norm" in list(optim_kwargs.keys()) + else 200 + ) + self.loss_fn = pyro.infer.Trace_ELBO() if loss_fn is None else loss_fn - self.optim = pyro.optim.Adam({"lr": 1e-3}) if optim is None else optim + self.optim = ( + pyro.optim.ClippedAdam(optim_args=optim_kwargs) if optim is None else optim + ) self.automatic_optimization = False self.pyro_guide = self.module.guide diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 9e30315589..f42a57a150 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -55,14 +55,19 @@ def create_plates(self, x, y, ind_x): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=ind_x) def list_obs_plate_vars(self): - """Create a dictionary with the name of observation/minibatch plate, - indexes of model args to provide to encoder, - variable names that belong to the observation plate - and the number of dimensions in non-plate axis of each variable""" + """Create a dictionary with: + 1. "name" - the name of observation/minibatch plate; + 2. "in" - indexes of model args to provide to encoder network when using amortised inference; + 3. 
"sites" - dictionary with + keys - names of variables that belong to the observation plate (used to recognise + and merge posterior samples for minibatch variables) + values - the dimensions in non-plate axis of each variable (used to construct output + layer of encoder network when using amortised inference) + """ return { "name": "obs_plate", - "in": [0], # index for expression data + "in": [0], # model args index for expression data "sites": {"per_cell_weights": 1}, } @@ -263,7 +268,7 @@ def test_pyro_bayesian_train_sample_mixin(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() mod = BayesianRegressionModel(adata, batch_size=128) - mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) # 100 features, 1 for sigma, 1 for bias # assert list(mod.module.guide.parameters())[0].shape[0] == 102 @@ -278,7 +283,7 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() mod = BayesianRegressionModel(adata, batch_size=None) - mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) # 100 features, 1 for sigma, 1 for bias # assert list(mod.module.guide.parameters())[0].shape[0] == 102 @@ -293,7 +298,12 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() mod = BayesianRegressionModel(adata, batch_size=128, per_cell_weight=True) - mod.train(max_epochs=2, lr=0.01, use_gpu=use_gpu) + mod.train( + max_epochs=2, + plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + train_size=1, # does not work when there is a validation set. + use_gpu=use_gpu, + ) # 100 features, 1 for sigma, 1 for bias # assert list(mod.module.guide.parameters())[0].shape[0] == 102 From 700aff23bde9ed9b69e3e6c9a931c14340bde635 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 14 May 2021 13:02:56 +0100 Subject: [PATCH 06/50] fixed parameter shape test --- tests/models/test_pyro.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index f42a57a150..1b2313f337 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -246,8 +246,11 @@ def test_pyro_bayesian_regression_jit(): ) trainer.fit(plan, train_dl) - # 100 features, 1 for sigma, 1 for bias - # assert list(model.guide.parameters())[0].shape[0] == 102 + # 100 features + assert list(model.guide.state_dict()["locs.linear.weight_unconstrained"].shape) == [ + 1, + 100, + ] if use_gpu == 1: model.cuda() @@ -270,8 +273,10 @@ def test_pyro_bayesian_train_sample_mixin(): mod = BayesianRegressionModel(adata, batch_size=128) mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) - # 100 features, 1 for sigma, 1 for bias - # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + # 100 features + assert list( + mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape + ) == [1, 100] # test posterior sampling samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) @@ -285,8 +290,10 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): mod = BayesianRegressionModel(adata, batch_size=None) mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) - # 100 features, 1 for sigma, 1 for bias - # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + # 100 features 
+ assert list( + mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape + ) == [1, 100] # test posterior sampling samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) @@ -305,8 +312,10 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): use_gpu=use_gpu, ) - # 100 features, 1 for sigma, 1 for bias - # assert list(mod.module.guide.parameters())[0].shape[0] == 102 + # 100 + assert list( + mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape + ) == [1, 100] # test posterior sampling samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) From 0fd40c417cb2d8795c030966e1fa3ea1927fd170 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 15 May 2021 23:27:28 +0100 Subject: [PATCH 07/50] deleted full data specific training code --- scvi/model/base/_pyromixin.py | 239 +--------------------------------- tests/models/test_pyro.py | 14 +- 2 files changed, 3 insertions(+), 250 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a3767691be..32e7423cb2 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -2,16 +2,13 @@ from typing import Optional, Union import numpy as np -import pandas as pd -import pyro import torch from pyro import poutine -from pyro.infer import SVI from pytorch_lightning.callbacks import Callback from scvi.dataloaders import AnnDataLoader, DataSplitter from scvi.model._utils import parse_use_gpu_arg -from scvi.train import PyroTrainingPlan, Trainer, TrainRunner +from scvi.train import PyroTrainingPlan, TrainRunner from scvi.utils import track logger = logging.getLogger(__name__) @@ -97,9 +94,7 @@ def train( batch_size=batch_size, use_gpu=use_gpu, ) - training_plan = PyroTrainingPlan( - pyro_module=self.module, n_obs=len(data_splitter.train_idx), **plan_kwargs - ) + training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs) es = "early_stopping" trainer_kwargs[es] = ( @@ -115,236 +110,6 @@ def train( ) return runner() - def _train_full_data( - self, - max_epochs: Optional[int] = None, - use_gpu: bool = False, - plan_kwargs: Optional[dict] = None, - lr: float = 0.01, - optim_kwargs: Optional[dict] = None, - continue_training: bool = True, - ): - """ - Private method for training using full data. - - Parameters - ---------- - max_epochs - Number of training epochs / iterations - use_gpu - Bool, use gpu? - plan_kwargs - Training plan arguments such as optim and loss_fn - continue_training - When the model is already trained, should calling .train() continue training? 
(False = restart training) - - Returns - ------- - ELBO history in self.module.history_ - - """ - - args, kwargs = self.module.model._get_fn_args_full_data(self.adata) - gpus, device = parse_use_gpu_arg(use_gpu) - optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() - - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - if hasattr(self.module.model, "n_obs"): - setattr(self.module.model, "n_obs", self.adata.n_obs) - if hasattr(self.module.guide, "n_obs"): - setattr(self.module.guide, "n_obs", self.adata.n_obs) - - if not continue_training or not self.is_trained_: - # models share param store, make sure it is cleared before training - pyro.clear_param_store() - # initialise guide params (warmup) - self.module.guide(*args, **kwargs) - - svi = SVI( - self.module.model, - self.module.guide, - # select optimiser, optionally choosing different lr for autoencoding guide - pyro.optim.ClippedAdam(self._optim_param(lr, **optim_kwargs)), - loss=plan_kwargs["loss_fn"], - ) - - iter_iterator = track( - range(max_epochs), - style="tqdm", - ) - hist = [] - for it in iter_iterator: - - loss = svi.step(*args, **kwargs) - iter_iterator.set_description( - "Epoch " + "{:d}".format(it) + ", -ELBO: " + "{:.4e}".format(loss) - ) - hist.append(loss) - - if it % 500 == 0: - torch.cuda.empty_cache() - - if continue_training and self.is_trained_: - # add ELBO listory - hist = self.module.history_ + hist - self.module.history_ = hist - self.module.is_trained_ = True - self.history_ = hist - self.is_trained_ = True - - def _train_minibatch( - self, - max_epochs: Optional[int] = None, - max_steps: Optional[int] = None, - use_gpu: bool = False, - plan_kwargs: Optional[dict] = None, - trainer_kwargs: Optional[dict] = None, - lr: float = 0.01, - optim_kwargs: Optional[dict] = None, - early_stopping: bool = False, - continue_training: bool = True, - ): - """ - Private method for training using minibatches (scVI interface and pytorch lightning). - - Parameters - ---------- - max_epochs - Number of training epochs / iterations - max_steps - Number of training steps - use_gpu - Bool, use gpu? - plan_kwargs - Training plan arguments such as optim and loss_fn - trainer_kwargs - Arguments for scvi.train.Trainer. - optim_kwargs - optimiser creation arguments to such as autoencoding_lr, clip_norm, module_names - early_stopping - Bool, use early stopping? (not tested) - continue_training - When the model is already trained, should calling .train() continue training? 
(False = restart training) - - Returns - ------- - ELBO history in self.module.history_ - - """ - - if not continue_training or not self.is_trained_: - # models share param store, make sure it is cleared before training - pyro.clear_param_store() - - gpus, device = parse_use_gpu_arg(use_gpu) - if max_epochs is None: - n_obs = self.adata.n_obs - max_epochs = np.min([round((20000 / n_obs) * 400), 400]) - - plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() - plan_kwargs["n_obs"] = self.adata.n_obs - trainer_kwargs = trainer_kwargs if isinstance(trainer_kwargs, dict) else dict() - optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() - - batch_size = self.batch_size - # select optimiser, optionally choosing different lr for different parameters - plan_kwargs["optim"] = pyro.optim.ClippedAdam( - self._optim_param(lr, **optim_kwargs) - ) - - # create data loader for training - train_dl = AnnDataLoader(self.adata, shuffle=True, batch_size=batch_size) - plan = PyroTrainingPlan(self.module, **plan_kwargs) - es = "early_stopping" - trainer_kwargs[es] = ( - early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] - ) - trainer = Trainer( - gpus=gpus, - max_epochs=max_epochs, - max_steps=max_steps, - # callbacks=[PyroJitGuideWarmup(train_dl)], - **trainer_kwargs, - ) - trainer.fit(plan, train_dl) - self.module.to(device) - - try: - if continue_training and self.is_trained_: - # add ELBO listory - index = range( - len(self.module.history_), - len(self.module.history_) - + len(trainer.logger.history["train_loss_epoch"]), - ) - trainer.logger.history["train_loss_epoch"].index = index - self.module.history_ = pd.concat( - [self.module.history_, trainer.logger.history["train_loss_epoch"]] - ) - else: - self.module.history_ = trainer.logger.history["train_loss_epoch"] - self.history_ = self.module.history_ - except AttributeError: - self.history_ = None - - self.module.is_trained_ = True - self.is_trained_ = True - - def _optim_param(self, lr: float = 0.01, clip_norm: float = 200): - # create function which fetches different lr for different parameters - def optim_param(module_name, param_name): - return { - "lr": lr, - # limit the gradient step from becoming too large - "clip_norm": clip_norm, - } - - return optim_param - - def train_v2(self, **kwargs): - """ - Train the model. - - Parameters - ---------- - max_epochs - Number of training epochs / iterations - max_steps - Number of training steps - use_gpu - Bool, use gpu? - lr - Learning rate. - autoencoding_lr - Optional, a separate learning rate for encoder network. - clip_norm - Gradient clipping norm (useful for preventing exploding gradients, - which can lead to impossible values and NaN loss). - trainer_kwargs - Training plan arguments for scvi.train.PyroTrainingPlan (Excluding optim and loss_fn) - early_stopping - Bool, use early stopping? 
(not tested) - - Returns - ------- - ELBO history in self.module.history_ - - """ - - plan_kwargs = {"loss_fn": pyro.infer.Trace_ELBO()} - - batch_size = self.batch_size - - if batch_size is None: - # train using full data (faster for small datasets) - self._train_full_data(plan_kwargs=plan_kwargs, **kwargs) - else: - # standard training using minibatches - self._train_minibatch(plan_kwargs=plan_kwargs, **kwargs) - class PyroSampleMixin: """ diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 7487086cdc..fd843409b3 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -8,11 +8,9 @@ from anndata import AnnData from pyro.infer.autoguide import AutoNormal, init_to_mean from pyro.nn import PyroModule, PyroSample -from scipy.sparse import issparse from scvi import _CONSTANTS from scvi.data import register_tensor_from_anndata, synthetic_iid -from scvi.data._anndata import get_from_registry from scvi.dataloaders import AnnDataLoader from scvi.model.base import ( BaseModelClass, @@ -78,16 +76,6 @@ def _get_fn_args_from_batch(tensor_dict): ind_x = tensor_dict["ind_x"].long().squeeze() return (x, y, ind_x), {} - @staticmethod - def _get_fn_args_full_data(adata): - x = get_from_registry(adata, _CONSTANTS.X_KEY) - if issparse(x): - x = np.asarray(x.toarray()) - x = torch.tensor(x.astype("float32")) - ind_x = torch.tensor(get_from_registry(adata, "ind_x")) - y = torch.tensor(get_from_registry(adata, _CONSTANTS.LABELS_KEY)) - return (x, y, ind_x), {} - def forward(self, x, y, ind_x): obs_plate = self.create_plates(x, y, ind_x) @@ -168,7 +156,7 @@ def test_pyro_bayesian_regression(save_path): ) train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() - model = BayesianRegressionModule(adata.shape[1], 1) + model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) plan = PyroTrainingPlan(model) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( From d649947876494456dcf59e4ac9d1e6676a442c41 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 15 May 2021 23:45:55 +0100 Subject: [PATCH 08/50] posterior sampling with batch_size=None --- scvi/model/base/_pyromixin.py | 82 ++++++++++++++--------------------- tests/models/test_pyro.py | 4 ++ 2 files changed, 36 insertions(+), 50 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 32e7423cb2..a10ef7c909 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -87,13 +87,23 @@ def train( plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() - data_splitter = DataSplitter( - self.adata, - train_size=train_size, - validation_size=validation_size, - batch_size=batch_size, - use_gpu=use_gpu, - ) + if batch_size is None: + # use data splitter which moves data to GPU once + data_splitter = DataSplitter( + self.adata, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + use_gpu=use_gpu, + ) + else: + data_splitter = DataSplitter( + self.adata, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + use_gpu=use_gpu, + ) training_plan = PyroTrainingPlan(pyro_module=self.module, **plan_kwargs) es = "early_stopping" @@ -218,33 +228,6 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} - def _posterior_samples_full_data(self, use_gpu: bool = True, **sample_kwargs): - """ - Generate samples from posterior distribution using all data - - Parameters - ---------- - sample_kwargs - arguments 
to _get_posterior_samples - - Returns - ------- - dictionary {variable_name: [array with samples in 0 dimension]} - - """ - - self.module.eval() - gpus, device = parse_use_gpu_arg(use_gpu) - - args, kwargs = self.module.model._get_fn_args_full_data(self.adata) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) - - return samples - def _check_obs_plate_return_sites(self, sample_kwargs): # check whether any variable requested in return_sites are in obs_plate obs_plate_sites = list(self.module.model.list_obs_plate_vars()["sites"].keys()) @@ -296,9 +279,11 @@ def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): gpus, device = parse_use_gpu_arg(use_gpu) - self.module.eval() - - train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=self.batch_size) + if self.batch_size is None: + batch_size = self.adata.n_obs + else: + batch_size = self.batch_size + train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size) # sample local parameters i = 0 for tensor_dict in track( @@ -376,6 +361,7 @@ def sample_posterior( ): """ Generate samples from posterior distribution for each parameter + and compute mean, 5%/95% quantiles, standard deviation. Parameters ---------- @@ -394,9 +380,13 @@ def sample_posterior( Returns ------- - Posterior distribution samples, a dictionary for each of (mean, 5% quantile, SD, optionally all samples), - containing dictionaries for each variable with numpy arrays. - Dictionary of all samples contains samples for each variable as numpy arrays of shape ``(n_samples, ...)`` + Posterior distribution samples, a dictionary with elements as follows, + containing dictionaries of numpy arrays for each variable: + post_sample_means - mean of the distribution for each variable; + post_sample_q05 - 5% quantile; + post_sample_q95 - 95% quantile; + post_sample_sds - standard deviation + posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional) """ @@ -404,16 +394,8 @@ def sample_posterior( sample_kwargs["num_samples"] = num_samples sample_kwargs["return_sites"] = return_sites - if self.batch_size is None: - # sample using full data - samples = self._posterior_samples_full_data( - use_gpu=use_gpu, **sample_kwargs - ) - else: - # sample using minibatches - samples = self._posterior_samples_minibatch( - use_gpu=use_gpu, **sample_kwargs - ) + # sample using minibatches (if full data, data is moved to GPU only once anyway) + samples = self._posterior_samples_minibatch(use_gpu=use_gpu, **sample_kwargs) param_names = list(samples.keys()) results = dict() diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index fd843409b3..a7798145fc 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -240,6 +240,10 @@ def test_pyro_bayesian_regression_jit(): 1, 100, ] + # 1 bias + assert list(model.guide.state_dict()["locs.linear.bias_unconstrained"].shape) == [ + 1, + ] if use_gpu == 1: model.cuda() From c3da184b74d17e5437f9436d441024e86fe0a8c5 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 15 May 2021 23:48:27 +0100 Subject: [PATCH 09/50] get rid of for loop in sampling global param --- scvi/model/base/_pyromixin.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a10ef7c909..2eeb7e40e5 100755 --- 
a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -329,20 +329,18 @@ def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): i += 1 # sample global parameters - for tensor_dict in train_dl: - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) - global_samples = { - k: v - for k, v in global_samples.items() - if k - not in list(self.module.model.list_obs_plate_vars()["sites"].keys()) - } - break + tensor_dict = next(iter(train_dl)) + args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) + args = [a.to(device) for a in args] + kwargs = {k: v.to(device) for k, v in kwargs.items()} + self.to_device(device) + + global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) + global_samples = { + k: v + for k, v in global_samples.items() + if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys()) + } for k in global_samples.keys(): samples[k] = global_samples[k] From 264d6831e0b0201c29d3a4fa6acad6c1f9747aa8 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sun, 16 May 2021 19:02:15 +0100 Subject: [PATCH 10/50] device-backed dataloader, providing batch size to sample_posterior, test for full data local var --- scvi/model/base/_pyromixin.py | 17 +++++---- tests/models/test_pyro.py | 65 +++++++++++++++++++++++++++++------ 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 2eeb7e40e5..1bb98d81c3 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -6,7 +6,7 @@ from pyro import poutine from pytorch_lightning.callbacks import Callback -from scvi.dataloaders import AnnDataLoader, DataSplitter +from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter from scvi.model._utils import parse_use_gpu_arg from scvi.train import PyroTrainingPlan, TrainRunner from scvi.utils import track @@ -89,7 +89,7 @@ def train( if batch_size is None: # use data splitter which moves data to GPU once - data_splitter = DataSplitter( + data_splitter = DeviceBackedDataSplitter( self.adata, train_size=train_size, validation_size=validation_size, @@ -258,7 +258,9 @@ def _find_plate_dimension(self, args, kwargs): } return obs_plate - def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): + def _posterior_samples_minibatch( + self, use_gpu: bool = True, batch_size: int = 128, **sample_kwargs + ): """ Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables and global variables, which is necessary when performing minibatch inference. 
@@ -279,10 +281,8 @@ def _posterior_samples_minibatch(self, use_gpu: bool = True, **sample_kwargs): gpus, device = parse_use_gpu_arg(use_gpu) - if self.batch_size is None: + if batch_size is None: batch_size = self.adata.n_obs - else: - batch_size = self.batch_size train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size) # sample local parameters i = 0 @@ -354,6 +354,7 @@ def sample_posterior( num_samples: int = 1000, return_sites: Optional[list] = None, use_gpu: bool = False, + batch_size: int = 128, sample_kwargs=None, return_samples: bool = False, ): @@ -393,7 +394,9 @@ def sample_posterior( sample_kwargs["return_sites"] = return_sites # sample using minibatches (if full data, data is moved to GPU only once anyway) - samples = self._posterior_samples_minibatch(use_gpu=use_gpu, **sample_kwargs) + samples = self._posterior_samples_minibatch( + use_gpu=use_gpu, batch_size=batch_size, **sample_kwargs + ) param_names = list(samples.keys()) results = dict() diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index a7798145fc..c697878571 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -119,7 +119,6 @@ class BayesianRegressionModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass def __init__( self, adata: AnnData, - batch_size=None, per_cell_weight=False, ): # add index for each cell (provided to pyro plate for correct minibatching) @@ -133,7 +132,6 @@ def __init__( super().__init__(adata) - self.batch_size = batch_size self.module = BayesianRegressionModule( in_features=adata.shape[1], out_features=1, @@ -263,8 +261,13 @@ def test_pyro_bayesian_regression_jit(): def test_pyro_bayesian_train_sample_mixin(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() - mod = BayesianRegressionModel(adata, batch_size=128) - mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) + mod = BayesianRegressionModel(adata) + mod.train( + max_epochs=2, + batch_size=128, + plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + use_gpu=use_gpu, + ) # 100 features assert list( @@ -272,7 +275,9 @@ def test_pyro_bayesian_train_sample_mixin(): ) == [1, 100] # test posterior sampling - samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + samples = mod.sample_posterior( + num_samples=10, use_gpu=use_gpu, batch_size=128, return_samples=True + ) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -280,8 +285,13 @@ def test_pyro_bayesian_train_sample_mixin(): def test_pyro_bayesian_train_sample_mixin_full_data(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() - mod = BayesianRegressionModel(adata, batch_size=None) - mod.train(max_epochs=2, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, use_gpu=use_gpu) + mod = BayesianRegressionModel(adata) + mod.train( + max_epochs=2, + batch_size=None, + plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + use_gpu=use_gpu, + ) # 100 features assert list( @@ -289,7 +299,9 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): ) == [1, 100] # test posterior sampling - samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + samples = mod.sample_posterior( + num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True + ) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -297,9 +309,10 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): def test_pyro_bayesian_train_sample_mixin_with_local(): use_gpu = torch.cuda.is_available() adata = synthetic_iid() - mod = BayesianRegressionModel(adata, 
batch_size=128, per_cell_weight=True) + mod = BayesianRegressionModel(adata, per_cell_weight=True) mod.train( max_epochs=2, + batch_size=128, plan_kwargs={"optim_kwargs": {"lr": 0.01}}, train_size=1, # does not work when there is a validation set. use_gpu=use_gpu, @@ -311,7 +324,39 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): ) == [1, 100] # test posterior sampling - samples = mod.sample_posterior(num_samples=10, use_gpu=use_gpu, return_samples=True) + samples = mod.sample_posterior( + num_samples=10, use_gpu=use_gpu, batch_size=128, return_samples=True + ) + + assert len(samples["posterior_samples"]["sigma"]) == 10 + assert samples["posterior_samples"]["per_cell_weights"].shape == ( + 10, + adata.n_obs, + 1, + ) + + +def test_pyro_bayesian_train_sample_mixin_with_local_full_data(): + use_gpu = torch.cuda.is_available() + adata = synthetic_iid() + mod = BayesianRegressionModel(adata, per_cell_weight=True) + mod.train( + max_epochs=2, + batch_size=None, + plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + train_size=1, # does not work when there is a validation set. + use_gpu=use_gpu, + ) + + # 100 + assert list( + mod.module.guide.state_dict()["locs.linear.weight_unconstrained"].shape + ) == [1, 100] + + # test posterior sampling + samples = mod.sample_posterior( + num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True + ) assert len(samples["posterior_samples"]["sigma"]) == 10 assert samples["posterior_samples"]["per_cell_weights"].shape == ( From f071f2f20472930ee400bc383d81e0914bc315f0 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 17 May 2021 01:58:12 +0100 Subject: [PATCH 11/50] clear pyro param store when initialising new model --- tests/models/test_pyro.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index c697878571..b8047d6edd 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from anndata import AnnData +from pyro import clear_param_store from pyro.infer.autoguide import AutoNormal, init_to_mean from pyro.nn import PyroModule, PyroSample @@ -121,6 +122,9 @@ def __init__( adata: AnnData, per_cell_weight=False, ): + # in case any other model was created before that shares the same parameter names. + clear_param_store() + # add index for each cell (provided to pyro plate for correct minibatching) adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") register_tensor_from_anndata( From 7b34c51c8d7e7e86e1ebdf2992d42c662cbae836 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 17 May 2021 12:36:06 +0100 Subject: [PATCH 12/50] Update scvi/model/base/_pyromixin.py Co-authored-by: Adam Gayoso --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 1bb98d81c3..a69ce3c43e 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -71,7 +71,7 @@ def train( Size of the test set. If `None`, defaults to 1 - `train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. batch_size - Minibatch size to use during training. + Minibatch size to use during training. If `None`, no minibatching occurs and all data is copied to device (e.g., GPU). early_stopping Perform early stopping. Additional arguments can be passed in `**kwargs`. See :class:`~scvi.train.Trainer` for further options. 
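
The training and sampling interface that patches 07-12 converge on is exercised directly in tests/models/test_pyro.py. As a compact illustration, a sketch assuming the BayesianRegressionModel test class from tests/models/test_pyro.py is importable; any model combining PyroSviTrainMixin and PyroSampleMixin exposes the same train() / sample_posterior() calls:

    import torch
    from scvi.data import synthetic_iid

    # BayesianRegressionModel is the example model defined in tests/models/test_pyro.py
    from tests.models.test_pyro import BayesianRegressionModel

    use_gpu = torch.cuda.is_available()
    adata = synthetic_iid()
    mod = BayesianRegressionModel(adata)

    # minibatch training via the standard scvi / pytorch-lightning path
    mod.train(
        max_epochs=2,
        batch_size=128,
        plan_kwargs={"optim_kwargs": {"lr": 0.01}},
        use_gpu=use_gpu,
    )

    # batch_size=None instead uses DeviceBackedDataSplitter: the whole dataset
    # is moved to the device once and training runs on the full data
    # mod.train(max_epochs=2, batch_size=None, use_gpu=use_gpu)

    # posterior summaries (means, 5%/95% quantiles, SDs); return_samples=True
    # additionally keeps the raw draws under "posterior_samples"
    samples = mod.sample_posterior(
        num_samples=10, use_gpu=use_gpu, batch_size=128, return_samples=True
    )
    assert len(samples["posterior_samples"]["sigma"]) == 10
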
From b020ff5fef9248df17f07d7ed1a6f794d8f7673e Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 17 May 2021 12:39:07 +0100 Subject: [PATCH 13/50] Update scvi/model/base/_pyromixin.py Co-authored-by: Adam Gayoso --- scvi/model/base/_pyromixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a69ce3c43e..c3289d8206 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -24,6 +24,7 @@ def __init__(self, train_dl) -> None: def on_train_start(self, trainer, pl_module): """ Way to warmup Pyro Guide in an automated way. + Also device agnostic. """ From cb3cae91fc6bab18b67b193520decbc8f526aeb6 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 19 May 2021 18:56:08 +0100 Subject: [PATCH 14/50] guide callback --- scvi/model/base/_pyromixin.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 1bb98d81c3..83fbe4ae95 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -110,6 +110,13 @@ def train( trainer_kwargs[es] = ( early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] ) + + if "callbacks" not in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] = [] + trainer_kwargs["callbacks"].append( + PyroJitGuideWarmup(data_splitter.train_dataloader()) + ) + runner = TrainRunner( self, training_plan=training_plan, From 97288cb2885685a228536703ea561e83604a69b2 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Wed, 19 May 2021 19:03:10 +0100 Subject: [PATCH 15/50] commented guide callback --- scvi/model/base/_pyromixin.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index ffa69a510c..ef9aaee122 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -112,11 +112,11 @@ def train( early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] ) - if "callbacks" not in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] = [] - trainer_kwargs["callbacks"].append( - PyroJitGuideWarmup(data_splitter.train_dataloader()) - ) + # if "callbacks" not in trainer_kwargs.keys(): + # trainer_kwargs["callbacks"] = [] + # trainer_kwargs["callbacks"].append( + # PyroJitGuideWarmup(data_splitter.train_dataloader()) + # ) runner = TrainRunner( self, From 08dbe1f8c27073838343c79e70635ac3048f3f38 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 22 May 2021 12:09:54 +0100 Subject: [PATCH 16/50] list_obs_plate_vars as module property, incl default --- scvi/model/base/_pyromixin.py | 37 +++++++++++++++++++++----------- scvi/module/base/_base_module.py | 15 +++++++++++++ tests/models/test_pyro.py | 9 ++++++-- 3 files changed, 47 insertions(+), 14 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index ef9aaee122..f2b7857a99 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -236,9 +236,15 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} + def _get_obs_plate_sites(self): + # get a list of observation/minibatch plate sites + return list(self.module.list_obs_plate_vars["sites"].keys()) + def _check_obs_plate_return_sites(self, sample_kwargs): + + obs_plate_sites = self._get_obs_plate_sites() + # check whether any variable requested in return_sites are in obs_plate - obs_plate_sites = 
list(self.module.model.list_obs_plate_vars()["sites"].keys()) if ("return_sites" in sample_kwargs.keys()) and ( sample_kwargs["return_sites"] is not None ): @@ -253,18 +259,18 @@ def _check_obs_plate_return_sites(self, sample_kwargs): def _find_plate_dimension(self, args, kwargs): + plate_name = self.module.list_obs_plate_vars["name"] + # find plate dimension trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) obs_plate = { name: site["cond_indep_stack"][0].dim for name, site in trace.nodes.items() if site["type"] == "sample" - if any( - f.name == self.module.model.list_obs_plate_vars()["name"] - for f in site["cond_indep_stack"] - ) + if any(f.name == plate_name for f in site["cond_indep_stack"]) } - return obs_plate + + return list(obs_plate.values())[0] def _posterior_samples_minibatch( self, use_gpu: bool = True, batch_size: int = 128, **sample_kwargs @@ -273,9 +279,6 @@ def _posterior_samples_minibatch( Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables and global variables, which is necessary when performing minibatch inference. - Note for developers: requires model class method which lists observation/minibatch plate - variables (self.module.model.list_obs_plate_vars()). - Parameters ---------- use_gpu @@ -285,8 +288,16 @@ def _posterior_samples_minibatch( ------- dictionary {variable_name: [array with samples in 0 dimension]} + Notes + ----- + Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) which lists + observation/minibatch plate name and variables. + This dictionary can be returned by model class method self.module.model.list_obs_plate_vars(). + """ + samples = dict() + gpus, device = parse_use_gpu_arg(use_gpu) if batch_size is None: @@ -310,8 +321,10 @@ def _posterior_samples_minibatch( "return_sites" ] = self._check_obs_plate_return_sites(sample_kwargs) sample_kwargs_obs_plate["show_progress"] = False - obs_plate = self._find_plate_dimension(args, kwargs) - obs_plate_dim = list(obs_plate.values())[0] + if len(sample_kwargs_obs_plate["return_sites"]) == 0: + # if no local variables - don't sample + break + obs_plate_dim = self._find_plate_dimension(args, kwargs) samples = self._get_posterior_samples( args, kwargs, **sample_kwargs_obs_plate ) @@ -347,7 +360,7 @@ def _posterior_samples_minibatch( global_samples = { k: v for k, v in global_samples.items() - if k not in list(self.module.model.list_obs_plate_vars()["sites"].keys()) + if k not in self._get_obs_plate_sites() } for k in global_samples.keys(): diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index 837ebc6780..556de46df8 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -261,6 +261,21 @@ def model(self): def guide(self): pass + @property + def list_obs_plate_vars(self): + """Model annotation for minibatch training with pyro plate. + + A dictionary with: + 1. "name" - the name of observation/minibatch plate; + 2. "in" - indexes of model args to provide to encoder network when using amortised inference; + 3. 
"sites" - dictionary with + keys - names of variables that belong to the observation plate (used to recognise + and merge posterior samples for minibatch variables) + values - the dimensions in non-plate axis of each variable (used to construct output + layer of encoder network when using amortised inference) + """ + return {"name": "", "in": [], "sites": {}} + def create_predictive( self, model: Optional[Callable] = None, diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index b8047d6edd..2eeb1b6d2a 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -54,7 +54,9 @@ def create_plates(self, x, y, ind_x): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=ind_x) def list_obs_plate_vars(self): - """Create a dictionary with: + """Model annotation for minibatch training with pyro plate. + + A dictionary with: 1. "name" - the name of observation/minibatch plate; 2. "in" - indexes of model args to provide to encoder network when using amortised inference; 3. "sites" - dictionary with @@ -63,7 +65,6 @@ def list_obs_plate_vars(self): values - the dimensions in non-plate axis of each variable (used to construct output layer of encoder network when using amortised inference) """ - return { "name": "obs_plate", "in": [0], # model args index for expression data @@ -115,6 +116,10 @@ def model(self): def guide(self): return self._guide + @property + def list_obs_plate_vars(self): + return self.model.list_obs_plate_vars() + class BayesianRegressionModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): def __init__( From b53b75b42940c8937821a5f6cbba06d8caa509ff Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 22 May 2021 12:20:09 +0100 Subject: [PATCH 17/50] updated docs --- scvi/model/base/_pyromixin.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index f2b7857a99..bcfc2be2b8 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -276,6 +276,8 @@ def _posterior_samples_minibatch( self, use_gpu: bool = True, batch_size: int = 128, **sample_kwargs ): """ + Generate samples of the posterior distribution in minibatches + Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables and global variables, which is necessary when performing minibatch inference. @@ -290,9 +292,11 @@ def _posterior_samples_minibatch( Notes ----- - Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) which lists - observation/minibatch plate name and variables. - This dictionary can be returned by model class method self.module.model.list_obs_plate_vars(). + Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) + which lists observation/minibatch plate name and variables. + See PyroBaseModuleClass.list_obs_plate_vars for details. + This dictionary can be returned by model class method self.module.model.list_obs_plate_vars() + to keep all model-specific variables in one place. 
""" From 0cb9e6038f1739072a1a0a06a8f73c26d3a43a04 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Sat, 22 May 2021 14:38:13 +0100 Subject: [PATCH 18/50] guess obs plate sites without list_obs_plate_vars --- scvi/model/base/_pyromixin.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index bcfc2be2b8..26460fb007 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -236,13 +236,7 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} - def _get_obs_plate_sites(self): - # get a list of observation/minibatch plate sites - return list(self.module.list_obs_plate_vars["sites"].keys()) - - def _check_obs_plate_return_sites(self, sample_kwargs): - - obs_plate_sites = self._get_obs_plate_sites() + def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): # check whether any variable requested in return_sites are in obs_plate if ("return_sites" in sample_kwargs.keys()) and ( @@ -257,7 +251,7 @@ def _check_obs_plate_return_sites(self, sample_kwargs): else: return obs_plate_sites - def _find_plate_dimension(self, args, kwargs): + def _get_obs_plate_sites(self, args, kwargs): plate_name = self.module.list_obs_plate_vars["name"] @@ -270,7 +264,7 @@ def _find_plate_dimension(self, args, kwargs): if any(f.name == plate_name for f in site["cond_indep_stack"]) } - return list(obs_plate.values())[0] + return obs_plate def _posterior_samples_minibatch( self, use_gpu: bool = True, batch_size: int = 128, **sample_kwargs @@ -320,15 +314,20 @@ def _posterior_samples_minibatch( self.to_device(device) if i == 0: + obs_plate_sites = self._get_obs_plate_sites(args, kwargs) + if len(obs_plate_sites) == 0: + # if no local variables - don't sample + break + obs_plate_dim = list(obs_plate_sites.values())[0] + sample_kwargs_obs_plate = sample_kwargs.copy() sample_kwargs_obs_plate[ "return_sites" - ] = self._check_obs_plate_return_sites(sample_kwargs) + ] = self._get_obs_plate_return_sites( + sample_kwargs, list(obs_plate_sites.keys()) + ) sample_kwargs_obs_plate["show_progress"] = False - if len(sample_kwargs_obs_plate["return_sites"]) == 0: - # if no local variables - don't sample - break - obs_plate_dim = self._find_plate_dimension(args, kwargs) + samples = self._get_posterior_samples( args, kwargs, **sample_kwargs_obs_plate ) @@ -364,7 +363,7 @@ def _posterior_samples_minibatch( global_samples = { k: v for k, v in global_samples.items() - if k not in self._get_obs_plate_sites() + if k not in list(obs_plate_sites.keys()) } for k in global_samples.keys(): From ed796c00687e45d0e2f0f8412ab19fddfdd54ee7 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 11:53:13 +0100 Subject: [PATCH 19/50] fixed guide warmup dataloader callback; --- scvi/model/base/_pyromixin.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 26460fb007..87cd5babfd 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -112,11 +112,12 @@ def train( early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] ) - # if "callbacks" not in trainer_kwargs.keys(): - # trainer_kwargs["callbacks"] = [] - # trainer_kwargs["callbacks"].append( - # PyroJitGuideWarmup(data_splitter.train_dataloader()) - # ) + data_splitter.setup() + if "callbacks" not in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] = [] + 
trainer_kwargs["callbacks"].append( + PyroJitGuideWarmup(data_splitter.train_dataloader()) + ) runner = TrainRunner( self, From deddaab26be95417ba2b2c84c22f2dc1c3752ea8 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 11:57:32 +0100 Subject: [PATCH 20/50] remove redundant dataloader in test --- tests/models/test_pyro.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 2eeb1b6d2a..a9021de291 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -234,7 +234,6 @@ def test_pyro_bayesian_regression_jit(): train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) pyro.clear_param_store() model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) - train_dl = AnnDataLoader(adata, shuffle=True, batch_size=128) plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO()) plan.n_obs_training = len(train_dl.indices) trainer = Trainer( From 0621d8cfd33e63ff2db0c96877620dae8edc90e0 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 12:57:30 +0100 Subject: [PATCH 21/50] gpu None default --- scvi/model/base/_pyromixin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 87cd5babfd..26bd4b5493 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -268,7 +268,7 @@ def _get_obs_plate_sites(self, args, kwargs): return obs_plate def _posterior_samples_minibatch( - self, use_gpu: bool = True, batch_size: int = 128, **sample_kwargs + self, use_gpu: bool = None, batch_size: int = 128, **sample_kwargs ): """ Generate samples of the posterior distribution in minibatches @@ -289,7 +289,7 @@ def _posterior_samples_minibatch( ----- Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) which lists observation/minibatch plate name and variables. - See PyroBaseModuleClass.list_obs_plate_vars for details. + See PyroBaseModuleClass.list_obs_plate_vars for details of the variables it should contain. This dictionary can be returned by model class method self.module.model.list_obs_plate_vars() to keep all model-specific variables in one place. @@ -378,7 +378,7 @@ def sample_posterior( self, num_samples: int = 1000, return_sites: Optional[list] = None, - use_gpu: bool = False, + use_gpu: bool = None, batch_size: int = 128, sample_kwargs=None, return_samples: bool = False, From 40146c624ba751b76b3ff876a8b2a1b362bca02f Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:01:51 +0100 Subject: [PATCH 22/50] consistent use GPU docs --- scvi/model/base/_pyromixin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 26bd4b5493..00daa1d1f1 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -279,7 +279,8 @@ def _posterior_samples_minibatch( Parameters ---------- use_gpu - Bool, use gpu? + Load model on default GPU if available (if None or True), + or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). Returns ------- @@ -394,7 +395,8 @@ def sample_posterior( return_sites get samples for pyro model variable, default is all variables, otherwise list variable names). use_gpu - Use gpu? + Load model on default GPU if available (if None or True), + or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). 
sample_kwargs dictionary with arguments to _get_posterior_samples (see below): return_observed From 151ccea477cafe843619c4e37e234fad170897ec Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:04:29 +0100 Subject: [PATCH 23/50] do not move data to device again for sampling global variables --- scvi/model/base/_pyromixin.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 00daa1d1f1..1108bface8 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -355,12 +355,6 @@ def _posterior_samples_minibatch( i += 1 # sample global parameters - tensor_dict = next(iter(train_dl)) - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) global_samples = { k: v From 27b3ecf2137a637228edd4816b706a642a6498b5 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:22:16 +0100 Subject: [PATCH 24/50] updated docs and lr interface for default optimiser --- scvi/model/base/_pyromixin.py | 6 ++++++ scvi/train/_trainingplans.py | 15 ++++++--------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 1108bface8..38996231d0 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -52,6 +52,7 @@ def train( validation_size: Optional[float] = None, batch_size: int = 128, early_stopping: bool = False, + lr: Optional[float] = None, plan_kwargs: Optional[dict] = None, **trainer_kwargs, ): @@ -76,6 +77,9 @@ def train( early_stopping Perform early stopping. Additional arguments can be passed in `**kwargs`. See :class:`~scvi.train.Trainer` for further options. + lr + Optimiser learning rate (default optimiser is Pyro ClippedAdam). + Specifying optimiser via plan_kwargs overrides this choice of lr. plan_kwargs Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to `train()` will overwrite values present in `plan_kwargs`, when appropriate. @@ -87,6 +91,8 @@ def train( max_epochs = np.min([round((20000 / n_obs) * 1000), 1000]) plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() + if lr is not None and "optim" not in plan_kwargs.keys(): + plan_kwargs.update({"optim_kwargs": {"lr": lr}}) if batch_size is None: # use data splitter which moves data to GPU once diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index 46702313a3..5aee185769 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -593,6 +593,8 @@ class PyroTrainingPlan(pl.LightningModule): optim A Pyro optimizer, e.g., :class:`~pyro.optim.Adam`. If `None`, defaults to Adam optimizer with a learning rate of `1e-3`. + optim_kwargs + Keyword arguments for default optimiser (pyro.optim.ClippedAdam). 
""" def __init__( @@ -600,7 +602,6 @@ def __init__( pyro_module: PyroBaseModuleClass, loss_fn: Optional[pyro.infer.ELBO] = None, optim: Optional[pyro.optim.PyroOptim] = None, - n_obs: Optional[int] = None, optim_kwargs: Optional[dict] = None, ): super().__init__() @@ -608,14 +609,10 @@ def __init__( self._n_obs_training = None optim_kwargs = optim_kwargs if isinstance(optim_kwargs, dict) else dict() - optim_kwargs["lr"] = ( - optim_kwargs["lr"] if "lr" in list(optim_kwargs.keys()) else 1e-3 - ) - optim_kwargs["clip_norm"] = ( - optim_kwargs["clip_norm"] - if "clip_norm" in list(optim_kwargs.keys()) - else 200 - ) + if "lr" not in optim_kwargs.keys(): + optim_kwargs.update({"lr": 1e-3}) + if "clip_norm" not in optim_kwargs.keys(): + optim_kwargs.update({"clip_norm": 200}) self.loss_fn = pyro.infer.Trace_ELBO() if loss_fn is None else loss_fn self.optim = ( From 38ba6ba940ff96b7d260abec627c509b86230740 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:31:22 +0100 Subject: [PATCH 25/50] updated docs --- scvi/model/base/_pyromixin.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 38996231d0..e323bd0fe5 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -39,9 +39,9 @@ def on_train_start(self, trainer, pl_module): class PyroSviTrainMixin: """ - This mixin class provides methods for: + Mixin class for training Pyro models - - training models using minibatches and using full data (copies data to GPU only once). + Training using minibatches and using full data (copies data to GPU only once). """ def train( @@ -138,9 +138,9 @@ def train( class PyroSampleMixin: """ - This mixin class provides methods for: + Mixin class for generating samples from posterior distribution - - generating samples from posterior distribution using minibatches and full data + Works using both minibatches and full data. """ @torch.no_grad() @@ -151,7 +151,8 @@ def _get_one_posterior_sample( return_sites: Optional[list] = None, sample_observed: bool = False, ): - """Get one sample from posterior distribution. + """ + Get one sample from posterior distribution. Parameters ---------- @@ -203,7 +204,7 @@ def _get_posterior_samples( show_progress: bool = True, ): """ - Get many samples from posterior distribution. + Get many (num_samples=N) samples from posterior distribution. Parameters ---------- @@ -244,6 +245,9 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): + """ + Check sample_kwargs["return_sites"] for overlap with observation/minibatch plate sites + """ # check whether any variable requested in return_sites are in obs_plate if ("return_sites" in sample_kwargs.keys()) and ( @@ -259,6 +263,23 @@ def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): return obs_plate_sites def _get_obs_plate_sites(self, args, kwargs): + """ + Automatically guess which model sites belong to observation/minibatch plate + + This function requires minibatch plate name specified in `self.module.list_obs_plate_vars["name"]`. + + Parameters + ---------- + args + Arguments to the model. + kwargs + Keyword arguments to the model. + + Returns + ------- + Dictionary with keys corresponding to site names and values to plate dimension. 
+ + """ plate_name = self.module.list_obs_plate_vars["name"] @@ -385,6 +406,8 @@ def sample_posterior( return_samples: bool = False, ): """ + Summarise posterior distribution + Generate samples from posterior distribution for each parameter and compute mean, 5%/95% quantiles, standard deviation. From 28e03fca85f61c69801fa1647c2733672a6b0e35 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:32:24 +0100 Subject: [PATCH 26/50] updated docs --- scvi/model/base/_pyromixin.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index e323bd0fe5..473cb2592d 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -39,7 +39,7 @@ def on_train_start(self, trainer, pl_module): class PyroSviTrainMixin: """ - Mixin class for training Pyro models + Mixin class for training Pyro models. Training using minibatches and using full data (copies data to GPU only once). """ @@ -138,7 +138,7 @@ def train( class PyroSampleMixin: """ - Mixin class for generating samples from posterior distribution + Mixin class for generating samples from posterior distribution. Works using both minibatches and full data. """ @@ -246,7 +246,7 @@ def _get_posterior_samples( def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): """ - Check sample_kwargs["return_sites"] for overlap with observation/minibatch plate sites + Check sample_kwargs["return_sites"] for overlap with observation/minibatch plate sites. """ # check whether any variable requested in return_sites are in obs_plate @@ -264,7 +264,7 @@ def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): def _get_obs_plate_sites(self, args, kwargs): """ - Automatically guess which model sites belong to observation/minibatch plate + Automatically guess which model sites belong to observation/minibatch plate. This function requires minibatch plate name specified in `self.module.list_obs_plate_vars["name"]`. @@ -298,7 +298,7 @@ def _posterior_samples_minibatch( self, use_gpu: bool = None, batch_size: int = 128, **sample_kwargs ): """ - Generate samples of the posterior distribution in minibatches + Generate samples of the posterior distribution in minibatches. Generate samples of the posterior distribution of each parameter, separating local (minibatch) variables and global variables, which is necessary when performing minibatch inference. @@ -406,7 +406,7 @@ def sample_posterior( return_samples: bool = False, ): """ - Summarise posterior distribution + Summarise posterior distribution. Generate samples from posterior distribution for each parameter and compute mean, 5%/95% quantiles, standard deviation. From 7dd6d80407333f576deeedfb708a5e9cd39ec329 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 24 May 2021 13:46:26 +0100 Subject: [PATCH 27/50] more flexible sample_posterior with summary_fun dict --- scvi/model/base/_pyromixin.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 473cb2592d..bf76d0bba9 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -404,6 +404,7 @@ def sample_posterior( batch_size: int = 128, sample_kwargs=None, return_samples: bool = False, + summary_fun: Optional[dict] = None, ): """ Summarise posterior distribution. @@ -415,7 +416,7 @@ def sample_posterior( ---------- num_samples number of posterior samples to generate. 
- return_sites + return_site get samples for pyro model variable, default is all variables, otherwise list variable names). use_gpu Load model on default GPU if available (if None or True), @@ -426,6 +427,9 @@ def sample_posterior( return observed sites/variables? return_samples return samples in addition to sample mean, 5%/95% quantile and SD? + summary_fun + a dict in the form {"means": np.mean, "std": np.std} which specifies posterior distribution + summaries to compute and which names to use. Returns ------- @@ -453,13 +457,16 @@ def sample_posterior( if return_samples: results["posterior_samples"] = samples - results["post_sample_means"] = {v: samples[v].mean(axis=0) for v in param_names} - results["post_sample_q05"] = { - v: np.quantile(samples[v], 0.05, axis=0) for v in param_names - } - results["post_sample_q95"] = { - v: np.quantile(samples[v], 0.95, axis=0) for v in param_names - } - results["post_sample_sds"] = {v: samples[v].std(axis=0) for v in param_names} + if summary_fun is None: + summary_fun = { + "means": np.mean, + "sts": np.std, + "q05": lambda x, axis: np.quantile(x, 0.05, axis=axis), + "q95": lambda x, axis: np.quantile(x, 0.95, axis=axis), + } + for k, fun in summary_fun.items(): + results[f"post_sample_{k}"] = { + v: fun(samples[v], axis=0) for v in param_names + } return results From a2a3376703baf7bd85b3d0740da600ff033827a9 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Mon, 24 May 2021 10:24:53 -0700 Subject: [PATCH 28/50] add to docs --- docs/api/developer.rst | 3 +++ scvi/model/base/_pyromixin.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/docs/api/developer.rst b/docs/api/developer.rst index 442f6de269..e7a8a5f9f7 100644 --- a/docs/api/developer.rst +++ b/docs/api/developer.rst @@ -64,6 +64,9 @@ These classes should be used to construct user-facing model classes. model.base.RNASeqMixin model.base.ArchesMixin model.base.UnsupervisedTrainingMixin + model.base.PyroSviTrainMixin + model.base.PyroSampleMixin + model.base.PyroJitGuideWarmup Module ------ diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index bf76d0bba9..b68c18a3c8 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -18,6 +18,12 @@ class PyroJitGuideWarmup(Callback): def __init__(self, train_dl) -> None: + """ + A callback to warmup a Pyro guide. + + This helps initialize all the relevant parameters by running + one minibatch through the Pyro model. + """ super().__init__() self.dl = train_dl From e830682b5f7ac7b7eaa14281769a654f84fab446 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 28 May 2021 14:04:47 +0100 Subject: [PATCH 29/50] Update scvi/model/base/_pyromixin.py Co-authored-by: Adam Gayoso --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index b68c18a3c8..a48b18e3c7 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -323,7 +323,7 @@ def _posterior_samples_minibatch( ----- Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) which lists observation/minibatch plate name and variables. - See PyroBaseModuleClass.list_obs_plate_vars for details of the variables it should contain. + See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. 
This dictionary can be returned by model class method self.module.model.list_obs_plate_vars() to keep all model-specific variables in one place. From b58901614c5c2b08aff88c46f32f4781dbca1bd2 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 28 May 2021 14:05:53 +0100 Subject: [PATCH 30/50] Update scvi/train/_trainingplans.py Co-authored-by: Adam Gayoso --- scvi/train/_trainingplans.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index 5aee185769..ee8a91167a 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -594,7 +594,7 @@ class PyroTrainingPlan(pl.LightningModule): A Pyro optimizer, e.g., :class:`~pyro.optim.Adam`. If `None`, defaults to Adam optimizer with a learning rate of `1e-3`. optim_kwargs - Keyword arguments for default optimiser (pyro.optim.ClippedAdam). + Keyword arguments for default optimiser :class:`pyro.optim.ClippedAdam`. """ def __init__( From bfa3685b9ba5def6b3e913f35f0a18964efdc23b Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 28 May 2021 14:06:36 +0100 Subject: [PATCH 31/50] Update scvi/model/base/_pyromixin.py Co-authored-by: Adam Gayoso --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a48b18e3c7..f4cab4e351 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -321,7 +321,7 @@ def _posterior_samples_minibatch( Notes ----- - Note for developers: requires scVI module property (a dictionary, self.module.list_obs_plate_vars) + Note for developers: requires overwritten :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` method. which lists observation/minibatch plate name and variables. See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. This dictionary can be returned by model class method self.module.model.list_obs_plate_vars() From 580992d8dcfaa06bffca66e797147c0b8f981591 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Fri, 28 May 2021 14:06:54 +0100 Subject: [PATCH 32/50] Update scvi/model/base/_pyromixin.py Co-authored-by: Adam Gayoso --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index f4cab4e351..09bdf3d7f9 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -324,7 +324,7 @@ def _posterior_samples_minibatch( Note for developers: requires overwritten :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` method. which lists observation/minibatch plate name and variables. See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. - This dictionary can be returned by model class method self.module.model.list_obs_plate_vars() + This dictionary can be returned by model class method `self.module.model.list_obs_plate_vars()` to keep all model-specific variables in one place. 
""" From dccfd74c512f1fadad9efb49a8f769faf6203789 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 19:52:05 +0100 Subject: [PATCH 33/50] default batch size and doc edits --- scvi/model/base/_pyromixin.py | 10 +++------- scvi/train/_trainingplans.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 09bdf3d7f9..74e3ca1dc6 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -6,6 +6,7 @@ from pyro import poutine from pytorch_lightning.callbacks import Callback +from scvi import settings from scvi.dataloaders import AnnDataLoader, DataSplitter, DeviceBackedDataSplitter from scvi.model._utils import parse_use_gpu_arg from scvi.train import PyroTrainingPlan, TrainRunner @@ -170,7 +171,6 @@ def _get_one_posterior_sample( Returns ------- Dictionary with a sample for each variable - """ guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) @@ -225,7 +225,6 @@ def _get_posterior_samples( ------- Dictionary with array of samples for each variable dictionary {variable_name: [array with samples in 0 dimension]} - """ samples = self._get_one_posterior_sample( @@ -284,7 +283,6 @@ def _get_obs_plate_sites(self, args, kwargs): Returns ------- Dictionary with keys corresponding to site names and values to plate dimension. - """ plate_name = self.module.list_obs_plate_vars["name"] @@ -326,15 +324,14 @@ def _posterior_samples_minibatch( See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. This dictionary can be returned by model class method `self.module.model.list_obs_plate_vars()` to keep all model-specific variables in one place. - """ samples = dict() gpus, device = parse_use_gpu_arg(use_gpu) - if batch_size is None: - batch_size = self.adata.n_obs + batch_size = batch_size if batch_size is not None else settings.batch_size + train_dl = AnnDataLoader(self.adata, shuffle=False, batch_size=batch_size) # sample local parameters i = 0 @@ -446,7 +443,6 @@ def sample_posterior( post_sample_q95 - 95% quantile; post_sample_sds - standard deviation posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional) - """ sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict() diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index ee8a91167a..a1cf5e5c8c 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -591,10 +591,10 @@ class PyroTrainingPlan(pl.LightningModule): A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`. If `None`, defaults to :class:`~pyro.infer.Trace_ELBO`. optim - A Pyro optimizer, e.g., :class:`~pyro.optim.Adam`. If `None`, - defaults to Adam optimizer with a learning rate of `1e-3`. + A Pyro optimizer instance, e.g., :class:`~pyro.optim.Adam`. If `None`, + defaults to ClippedAdam optimizer with a learning rate of `1e-3` and `clip_norm` of `200`. optim_kwargs - Keyword arguments for default optimiser :class:`pyro.optim.ClippedAdam`. + Keyword arguments for **default** optimiser :class:`pyro.optim.ClippedAdam`. 
""" def __init__( From b9b57b1c277139ecdb8bd7a4eb49719d5c436b76 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:03:59 +0100 Subject: [PATCH 34/50] updated tests and docs --- scvi/model/base/_pyromixin.py | 12 ++++++++---- tests/models/test_pyro.py | 16 ++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 74e3ca1dc6..578e5cef08 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -217,7 +217,7 @@ def _get_posterior_samples( args arguments to model and guide kwargs - arguments to model and guide + keyword arguments to model and guide show_progress show progress bar @@ -299,7 +299,7 @@ def _get_obs_plate_sites(self, args, kwargs): return obs_plate def _posterior_samples_minibatch( - self, use_gpu: bool = None, batch_size: int = 128, **sample_kwargs + self, use_gpu: bool = None, batch_size: Optional[int] = None, **sample_kwargs ): """ Generate samples of the posterior distribution in minibatches. @@ -312,6 +312,8 @@ def _posterior_samples_minibatch( use_gpu Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. Returns ------- @@ -404,7 +406,7 @@ def sample_posterior( num_samples: int = 1000, return_sites: Optional[list] = None, use_gpu: bool = None, - batch_size: int = 128, + batch_size: Optional[int] = None, sample_kwargs=None, return_samples: bool = False, summary_fun: Optional[dict] = None, @@ -424,6 +426,8 @@ def sample_posterior( use_gpu Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 
sample_kwargs dictionary with arguments to _get_posterior_samples (see below): return_observed @@ -437,7 +441,7 @@ def sample_posterior( Returns ------- Posterior distribution samples, a dictionary with elements as follows, - containing dictionaries of numpy arrays for each variable: + containing dictionaries of numpy arrays for each variable: post_sample_means - mean of the distribution for each variable; post_sample_q05 - 5% quantile; post_sample_q95 - 95% quantile; diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index a9021de291..cfa7066961 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -273,7 +273,7 @@ def test_pyro_bayesian_train_sample_mixin(): mod.train( max_epochs=2, batch_size=128, - plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + lr=0.01, use_gpu=use_gpu, ) @@ -284,7 +284,7 @@ def test_pyro_bayesian_train_sample_mixin(): # test posterior sampling samples = mod.sample_posterior( - num_samples=10, use_gpu=use_gpu, batch_size=128, return_samples=True + num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True ) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -297,7 +297,7 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): mod.train( max_epochs=2, batch_size=None, - plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + lr=0.01, use_gpu=use_gpu, ) @@ -308,7 +308,7 @@ def test_pyro_bayesian_train_sample_mixin_full_data(): # test posterior sampling samples = mod.sample_posterior( - num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True + num_samples=10, use_gpu=use_gpu, batch_size=adata.n_obs, return_samples=True ) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -321,7 +321,7 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): mod.train( max_epochs=2, batch_size=128, - plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + lr=0.01, train_size=1, # does not work when there is a validation set. use_gpu=use_gpu, ) @@ -333,7 +333,7 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): # test posterior sampling samples = mod.sample_posterior( - num_samples=10, use_gpu=use_gpu, batch_size=128, return_samples=True + num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True ) assert len(samples["posterior_samples"]["sigma"]) == 10 @@ -351,7 +351,7 @@ def test_pyro_bayesian_train_sample_mixin_with_local_full_data(): mod.train( max_epochs=2, batch_size=None, - plan_kwargs={"optim_kwargs": {"lr": 0.01}}, + lr=0.01, train_size=1, # does not work when there is a validation set. 
use_gpu=use_gpu, ) @@ -363,7 +363,7 @@ def test_pyro_bayesian_train_sample_mixin_with_local_full_data(): # test posterior sampling samples = mod.sample_posterior( - num_samples=10, use_gpu=use_gpu, batch_size=None, return_samples=True + num_samples=10, use_gpu=use_gpu, batch_size=adata.n_obs, return_samples=True ) assert len(samples["posterior_samples"]["sigma"]) == 10 From ee65f57bf4354a232493f0b7e1892a92970483fd Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:13:56 +0100 Subject: [PATCH 35/50] exposed sample_kwargs --- scvi/model/base/_pyromixin.py | 36 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 578e5cef08..99700dd765 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -249,16 +249,14 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} - def _get_obs_plate_return_sites(self, sample_kwargs, obs_plate_sites): + def _get_obs_plate_return_sites(self, return_sites, obs_plate_sites): """ - Check sample_kwargs["return_sites"] for overlap with observation/minibatch plate sites. + Check return_sites for overlap with observation/minibatch plate sites. """ # check whether any variable requested in return_sites are in obs_plate - if ("return_sites" in sample_kwargs.keys()) and ( - sample_kwargs["return_sites"] is not None - ): - return_sites = np.array(sample_kwargs["return_sites"]) + if return_sites is not None: + return_sites = np.array(return_sites) return_sites = return_sites[np.isin(return_sites, obs_plate_sites)] if len(return_sites) == 0: return [return_sites] @@ -358,7 +356,7 @@ def _posterior_samples_minibatch( sample_kwargs_obs_plate[ "return_sites" ] = self._get_obs_plate_return_sites( - sample_kwargs, list(obs_plate_sites.keys()) + sample_kwargs["return_sites"], list(obs_plate_sites.keys()) ) sample_kwargs_obs_plate["show_progress"] = False @@ -407,7 +405,7 @@ def sample_posterior( return_sites: Optional[list] = None, use_gpu: bool = None, batch_size: Optional[int] = None, - sample_kwargs=None, + return_observed: bool = False, return_samples: bool = False, summary_fun: Optional[dict] = None, ): @@ -420,20 +418,18 @@ def sample_posterior( Parameters ---------- num_samples - number of posterior samples to generate. + Number of posterior samples to generate. return_site - get samples for pyro model variable, default is all variables, otherwise list variable names). + List of variables for which to generate posterior samples, defaults to all variables. use_gpu Load model on default GPU if available (if None or True), or index of GPU to use (if int), or name of GPU (if str), or use CPU (if False). batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - sample_kwargs - dictionary with arguments to _get_posterior_samples (see below): - return_observed - return observed sites/variables? + return_observed + Return observed sites/variables? Observed count matrix can be very large so not returned by default. return_samples - return samples in addition to sample mean, 5%/95% quantile and SD? + Return samples in addition to sample mean, 5%/95% quantile and SD? summary_fun a dict in the form {"means": np.mean, "std": np.std} which specifies posterior distribution summaries to compute and which names to use. 
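# --- Editor's illustrative sketch (not part of the patch) ---------------------
# A custom `summary_fun`, as described in the parameter docs above: each entry
# maps a summary name to a callable taking (samples_array, axis), mirroring the
# defaults defined later in this file. `model` is a placeholder for a trained
# model that uses PyroSampleMixin; the chosen names determine how the returned
# summaries are keyed (see the Returns section for the default naming).
import numpy as np

def summarise_posterior(model):
    summary_fun = {
        "means": np.mean,
        "medians": lambda x, axis: np.quantile(x, 0.5, axis=axis),
        "q01": lambda x, axis: np.quantile(x, 0.01, axis=axis),
    }
    return model.sample_posterior(num_samples=100, summary_fun=summary_fun)
# ------------------------------------------------------------------------------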
@@ -449,13 +445,13 @@ def sample_posterior( posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional) """ - sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict() - sample_kwargs["num_samples"] = num_samples - sample_kwargs["return_sites"] = return_sites - # sample using minibatches (if full data, data is moved to GPU only once anyway) samples = self._posterior_samples_minibatch( - use_gpu=use_gpu, batch_size=batch_size, **sample_kwargs + use_gpu=use_gpu, + batch_size=batch_size, + num_samples=num_samples, + return_sites=return_sites, + return_observed=return_observed, ) param_names = list(samples.keys()) From 46c430f1a983f664650c57a24cef4a9a55f36318 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:34:54 +0100 Subject: [PATCH 36/50] updated docs --- scvi/model/base/_pyromixin.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 99700dd765..d1be9b476d 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import numpy as np import torch @@ -316,14 +316,6 @@ def _posterior_samples_minibatch( Returns ------- dictionary {variable_name: [array with samples in 0 dimension]} - - Notes - ----- - Note for developers: requires overwritten :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` method. - which lists observation/minibatch plate name and variables. - See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. - This dictionary can be returned by model class method `self.module.model.list_obs_plate_vars()` - to keep all model-specific variables in one place. """ samples = dict() @@ -407,7 +399,7 @@ def sample_posterior( batch_size: Optional[int] = None, return_observed: bool = False, return_samples: bool = False, - summary_fun: Optional[dict] = None, + summary_fun: Optional[Dict[str, Callable]] = None, ): """ Summarise posterior distribution. @@ -432,17 +424,26 @@ def sample_posterior( Return samples in addition to sample mean, 5%/95% quantile and SD? summary_fun a dict in the form {"means": np.mean, "std": np.std} which specifies posterior distribution - summaries to compute and which names to use. + summaries to compute and which names to use. See below for default returns. Returns ------- - Posterior distribution samples, a dictionary with elements as follows, - containing dictionaries of numpy arrays for each variable: + Dict[str, Dict[str, np.array]] + Posterior distribution samples, a dictionary with elements as follows, + containing dictionaries of numpy arrays for each variable: post_sample_means - mean of the distribution for each variable; post_sample_q05 - 5% quantile; post_sample_q95 - 95% quantile; - post_sample_sds - standard deviation - posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional) + post_sample_sds - standard deviation; + posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). + + Notes + ----- + Note for developers: requires overwritten :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` method. + which lists observation/minibatch plate name and variables. 
+ See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. + This dictionary can be returned by model class method `self.module.model.list_obs_plate_vars()` + to keep all model-specific variables in one place. """ # sample using minibatches (if full data, data is moved to GPU only once anyway) From 5a53fc95e7e1aad1d4189ee312e870d8ceeaec89 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:48:26 +0100 Subject: [PATCH 37/50] updated docs --- scvi/model/base/_pyromixin.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index d1be9b476d..60fc612940 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -431,11 +431,11 @@ def sample_posterior( Dict[str, Dict[str, np.array]] Posterior distribution samples, a dictionary with elements as follows, containing dictionaries of numpy arrays for each variable: - post_sample_means - mean of the distribution for each variable; - post_sample_q05 - 5% quantile; - post_sample_q95 - 95% quantile; - post_sample_sds - standard deviation; - posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). + 1) post_sample_means - mean of the distribution for each variable; + 2) post_sample_q05 - 5% quantile; + 3) post_sample_q95 - 95% quantile; + 4) post_sample_sds - standard deviation; + 5) posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). Notes ----- From 0eb4a16ed7c9e796fdbd227d614dafbb25210392 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:51:54 +0100 Subject: [PATCH 38/50] updated docs --- scvi/model/base/_pyromixin.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 60fc612940..a9e24c5004 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -428,14 +428,16 @@ def sample_posterior( Returns ------- - Dict[str, Dict[str, np.array]] - Posterior distribution samples, a dictionary with elements as follows, - containing dictionaries of numpy arrays for each variable: - 1) post_sample_means - mean of the distribution for each variable; - 2) post_sample_q05 - 5% quantile; - 3) post_sample_q95 - 95% quantile; - 4) post_sample_sds - standard deviation; - 5) posterior_samples - samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). + post_sample_means: Dict[str, np.array] + Mean of the posterior distribution for each variable, a dictionary of numpy arrays for each variable; + post_sample_q05: Dict[str, np.array] + 5% quantile of the posterior distribution for each variable; + post_sample_q05: Dict[str, np.array] + 95% quantile of the posterior distribution for each variable; + post_sample_q05: Dict[str, np.array] + Standard deviation of the posterior distribution for each variable; + posterior_samples: Dict[str, np.array] + Posterior distribution samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). 
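# --- Editor's illustrative sketch (not part of the patch) ---------------------
# Indexing the dictionary described in the Returns section above. `model` is a
# placeholder for a trained model using PyroSampleMixin; "sigma" stands in for
# any site name defined by the Pyro model (it is the site used in the tests).
def posterior_summary_for_sigma(model):
    samples = model.sample_posterior(num_samples=10, return_samples=True)
    sigma_mean = samples["post_sample_means"]["sigma"]   # posterior mean
    sigma_q05 = samples["post_sample_q05"]["sigma"]      # 5% quantile
    sigma_q95 = samples["post_sample_q95"]["sigma"]      # 95% quantile
    sigma_draws = samples["posterior_samples"]["sigma"]  # shape (10, ...)
    return sigma_mean, sigma_q05, sigma_q95, sigma_draws
# ------------------------------------------------------------------------------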
Notes ----- From 4bb2d5bae8c2ecc3df100bfd50dfb5f6416c4315 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 31 May 2021 20:52:37 +0100 Subject: [PATCH 39/50] updated docs --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a9e24c5004..c4550d5e20 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -436,7 +436,7 @@ def sample_posterior( 95% quantile of the posterior distribution for each variable; post_sample_q05: Dict[str, np.array] Standard deviation of the posterior distribution for each variable; - posterior_samples: Dict[str, np.array] + posterior_samples: Optional[Dict[str, np.array]] Posterior distribution samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). Notes From 9c38f005aed5986604f4658cb1a2b742d46e0506 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Tue, 1 Jun 2021 02:37:42 +0100 Subject: [PATCH 40/50] renamed sts -> stds --- scvi/model/base/_pyromixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index c4550d5e20..da0eb750a9 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -465,7 +465,7 @@ def sample_posterior( if summary_fun is None: summary_fun = { "means": np.mean, - "sts": np.std, + "stds": np.std, "q05": lambda x, axis: np.quantile(x, 0.05, axis=axis), "q95": lambda x, axis: np.quantile(x, 0.95, axis=axis), } From f9652f27202dda65e9065d3636368b19ef4c8108 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 11:38:07 -0700 Subject: [PATCH 41/50] fix model history issue for pyro on load --- scvi/model/base/_base_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index bdbd35ee2d..7ffb3a32fb 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -344,8 +344,10 @@ def load( model.module.load_state_dict(model_state_dict) except RuntimeError as err: if isinstance(model.module, PyroBaseModuleClass): + old_history = model.history_ logger.info("Preparing underlying module for load") model.train(max_steps=1) + model.history_ = old_history pyro.clear_param_store() model.module.load_state_dict(model_state_dict) else: From 9148c5db630583116623944ee1b2bc1e837789a8 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 11:38:29 -0700 Subject: [PATCH 42/50] documentation --- scvi/model/base/_pyromixin.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index da0eb750a9..8bdf8f7df1 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -18,13 +18,14 @@ class PyroJitGuideWarmup(Callback): - def __init__(self, train_dl) -> None: - """ - A callback to warmup a Pyro guide. + """ + A callback to warmup a Pyro guide. - This helps initialize all the relevant parameters by running - one minibatch through the Pyro model. - """ + This helps initialize all the relevant parameters by running + one minibatch through the Pyro model. + """ + + def __init__(self, train_dl: AnnDataLoader) -> None: super().__init__() self.dl = train_dl @@ -80,12 +81,13 @@ def train( Size of the test set. If `None`, defaults to 1 - `train_size`. If `train_size + validation_size < 1`, the remaining cells belong to a test set. 
batch_size - Minibatch size to use during training. If `None`, no minibatching occurs and all data is copied to device (e.g., GPU). + Minibatch size to use during training. If `None`, no minibatching occurs and all + data is copied to device (e.g., GPU). early_stopping Perform early stopping. Additional arguments can be passed in `**kwargs`. See :class:`~scvi.train.Trainer` for further options. lr - Optimiser learning rate (default optimiser is Pyro ClippedAdam). + Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). Specifying optimiser via plan_kwargs overrides this choice of lr. plan_kwargs Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to @@ -428,15 +430,15 @@ def sample_posterior( Returns ------- - post_sample_means: Dict[str, np.array] + post_sample_means: Dict[str, :class:`np.ndarray`] Mean of the posterior distribution for each variable, a dictionary of numpy arrays for each variable; - post_sample_q05: Dict[str, np.array] + post_sample_q05: Dict[str, :class:`np.ndarray`] 5% quantile of the posterior distribution for each variable; - post_sample_q05: Dict[str, np.array] + post_sample_q05: Dict[str, :class:`np.ndarray`] 95% quantile of the posterior distribution for each variable; - post_sample_q05: Dict[str, np.array] + post_sample_q05: Dict[str, :class:`np.ndarray`] Standard deviation of the posterior distribution for each variable; - posterior_samples: Optional[Dict[str, np.array]] + posterior_samples: Optional[Dict[str, :class:`np.ndarray`]] Posterior distribution samples for each variable as numpy arrays of shape `(n_samples, ...)` (Optional). Notes From 611448e8b6dc3c16b513fa39c3f0d1717c03b198 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 11:38:50 -0700 Subject: [PATCH 43/50] elbo should be average --- scvi/train/_trainingplans.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index a1cf5e5c8c..bfbe1a60f3 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -663,8 +663,11 @@ def training_step(self, batch, batch_idx): def training_epoch_end(self, outputs): elbo = 0 + n = 0 for out in outputs: elbo += out["loss"] + n += 1 + elbo /= n self.log("elbo_train", elbo, prog_bar=True) def configure_optimizers(self): From e9b1e01e32f2e5beff2aae07379d83fab3703282 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 11:44:27 -0700 Subject: [PATCH 44/50] codacy --- scvi/model/base/_pyromixin.py | 3 +-- scvi/module/base/_base_module.py | 3 ++- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 8bdf8f7df1..070dc77113 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -35,7 +35,6 @@ def on_train_start(self, trainer, pl_module): Also device agnostic. """ - # warmup guide for JIT pyro_guide = pl_module.module.guide for tensors in self.dl: @@ -322,7 +321,7 @@ def _posterior_samples_minibatch( samples = dict() - gpus, device = parse_use_gpu_arg(use_gpu) + _, device = parse_use_gpu_arg(use_gpu) batch_size = batch_size if batch_size is not None else settings.batch_size diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index 556de46df8..abd99f11a2 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -263,7 +263,8 @@ def guide(self): @property def list_obs_plate_vars(self): - """Model annotation for minibatch training with pyro plate. 
+ """ + Model annotation for minibatch training with pyro plate. A dictionary with: 1. "name" - the name of observation/minibatch plate; From 7238244b162978a4407f29dd60c084e0bdc02b20 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 12:15:49 -0700 Subject: [PATCH 45/50] codacy/docs --- scvi/model/base/_pyromixin.py | 1 - scvi/train/_trainingplans.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 070dc77113..5a14ccf5b2 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -173,7 +173,6 @@ def _get_one_posterior_sample( ------- Dictionary with a sample for each variable """ - guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) model_trace = poutine.trace( poutine.replay(self.module.model, guide_trace) diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index bfbe1a60f3..8d04d74cb5 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -585,7 +585,7 @@ class PyroTrainingPlan(pl.LightningModule): Parameters ---------- pyro_module - An instance of :class:`~scvi.compose.PyroBaseModuleClass`. This object + An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object should have callable `model` and `guide` attributes or methods. loss_fn A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`. From e15bf397835fb951aedb4facdb4492ed600da1d3 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 12:21:08 -0700 Subject: [PATCH 46/50] show attrs in docs --- docs/_templates/class_no_inherited.rst | 11 +++++++++++ scvi/model/base/_pyromixin.py | 6 +++--- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/docs/_templates/class_no_inherited.rst b/docs/_templates/class_no_inherited.rst index 5d7a8fad89..f02bae3db7 100644 --- a/docs/_templates/class_no_inherited.rst +++ b/docs/_templates/class_no_inherited.rst @@ -11,6 +11,17 @@ {% if methods %} .. rubric:: Methods + + .. autosummary:: + :toctree: . + {% for item in attributes %} + {%- if item not in inherited_members%} + ~{{ fullname }}.{{ item }} + {%- endif -%} + {%- endfor %} + {% endif %} + {% endblock %} + .. autosummary:: :toctree: . {% for item in methods %} diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 5a14ccf5b2..2f5c08360b 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -441,10 +441,10 @@ def sample_posterior( Notes ----- - Note for developers: requires overwritten :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` method. + Note for developers: requires overwritten :attr:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` property. which lists observation/minibatch plate name and variables. - See :func:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. - This dictionary can be returned by model class method `self.module.model.list_obs_plate_vars()` + See :attr:`~scvi.module.base.PyroBaseModuleClass.list_obs_plate_vars` for details of the variables it should contain. + This dictionary can be returned by model class property `self.module.model.list_obs_plate_vars` to keep all model-specific variables in one place. 
""" From 08cc1dea27c0eaf7a75e3fde0f475a9d51ce5d58 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 12:39:47 -0700 Subject: [PATCH 47/50] sphinx copy button version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6af8d214a4..7ccf184ac3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ sphinx = {version = ">=3.4", optional = true} sphinx-autodoc-typehints = {version = "*", optional = true} sphinx-gallery = {version = ">0.6", optional = true} sphinx-tabs = {version = "*", optional = true} -sphinx_copybutton = {version = "*", optional = true} +sphinx_copybutton = {version = "<=0.3.1", optional = true} torch = ">=1.8.0" tqdm = ">=4.56.0" typing_extensions = {version = "*", python = "<3.8"} From e4c30727063ab1b4bb39a5b14923a00708b1a043 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 12:50:04 -0700 Subject: [PATCH 48/50] other doc fixes --- scvi/model/base/_base_model.py | 2 ++ scvi/model/base/_pyromixin.py | 9 +-------- scvi/module/base/_base_module.py | 5 ++++- tests/models/test_pyro.py | 10 +++++++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index 7ffb3a32fb..62d67346b6 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -99,6 +99,8 @@ def _make_data_loader( Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. shuffle Whether observations are shuffled each iteration though + data_loader_class + Class to use for data loader data_loader_kwargs Kwargs to the class-specific data loader class """ diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index 2f5c08360b..4a52f9bad0 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -226,7 +226,6 @@ def _get_posterior_samples( Dictionary with array of samples for each variable dictionary {variable_name: [array with samples in 0 dimension]} """ - samples = self._get_one_posterior_sample( args, kwargs, return_sites=return_sites, sample_observed=return_observed ) @@ -250,10 +249,7 @@ def _get_posterior_samples( return {k: np.array(v) for k, v in samples.items()} def _get_obs_plate_return_sites(self, return_sites, obs_plate_sites): - """ - Check return_sites for overlap with observation/minibatch plate sites. - """ - + """Check return_sites for overlap with observation/minibatch plate sites.""" # check whether any variable requested in return_sites are in obs_plate if return_sites is not None: return_sites = np.array(return_sites) @@ -282,7 +278,6 @@ def _get_obs_plate_sites(self, args, kwargs): ------- Dictionary with keys corresponding to site names and values to plate dimension. """ - plate_name = self.module.list_obs_plate_vars["name"] # find plate dimension @@ -317,7 +312,6 @@ def _posterior_samples_minibatch( ------- dictionary {variable_name: [array with samples in 0 dimension]} """ - samples = dict() _, device = parse_use_gpu_arg(use_gpu) @@ -447,7 +441,6 @@ def sample_posterior( This dictionary can be returned by model class property `self.module.model.list_obs_plate_vars` to keep all model-specific variables in one place. 
""" - # sample using minibatches (if full data, data is moved to GPU only once anyway) samples = self._posterior_samples_minibatch( use_gpu=use_gpu, diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index abd99f11a2..9d1a55be8a 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -176,6 +176,7 @@ def inference( This function should return a dictionary with str keys and :class:`~torch.Tensor` values. """ + pass @abstractmethod def generative(self, *args, **kwargs) -> dict: @@ -187,6 +188,7 @@ def generative(self, *args, **kwargs) -> dict: This function should return a dictionary with str keys and :class:`~torch.Tensor` values. """ + pass @abstractmethod def loss(self, *args, **kwargs) -> LossRecorder: @@ -198,10 +200,12 @@ def loss(self, *args, **kwargs) -> LossRecorder: This function should return an object of type :class:`~scvi.module.base.LossRecorder`. """ + pass @abstractmethod def sample(self, *args, **kwargs): """Generate samples from the learned model.""" + pass def _get_dict_if_none(param): @@ -310,7 +314,6 @@ def create_predictive( in an outermost `plate` messenger. Note that this requires that the model has all batch dims correctly annotated via :class:`~pyro.plate`. Default is `False`. """ - if model is None: model = self.model if guide is None: diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index cfa7066961..909b9bd389 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -49,12 +49,16 @@ def __init__(self, in_features, out_features, per_cell_weight=False): ) def create_plates(self, x, y, ind_x): - """Function for creating plates is needed when using AutoGuides - and should have the same call signature as model""" + """ + Function for creating plates is needed when using AutoGuides. + + Should have the same call signature as model. + """ return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=ind_x) def list_obs_plate_vars(self): - """Model annotation for minibatch training with pyro plate. + """ + Model annotation for minibatch training with pyro plate. A dictionary with: 1. "name" - the name of observation/minibatch plate; From ea058990b6b9b9bc84ba4d1559769755c964e871 Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 12:51:22 -0700 Subject: [PATCH 49/50] fix class template --- docs/_templates/class_no_inherited.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/_templates/class_no_inherited.rst b/docs/_templates/class_no_inherited.rst index f02bae3db7..fd0d55add3 100644 --- a/docs/_templates/class_no_inherited.rst +++ b/docs/_templates/class_no_inherited.rst @@ -13,7 +13,7 @@ .. autosummary:: - :toctree: . + :toctree: . {% for item in attributes %} {%- if item not in inherited_members%} ~{{ fullname }}.{{ item }} From 52759bbab94f25c1e7666d62c987399e6b1d18cb Mon Sep 17 00:00:00 2001 From: adamgayoso Date: Sat, 19 Jun 2021 13:02:28 -0700 Subject: [PATCH 50/50] template --- docs/_templates/class_no_inherited.rst | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/_templates/class_no_inherited.rst b/docs/_templates/class_no_inherited.rst index fd0d55add3..55fd5f50e3 100644 --- a/docs/_templates/class_no_inherited.rst +++ b/docs/_templates/class_no_inherited.rst @@ -7,10 +7,9 @@ .. autoclass:: {{ objname }} :show-inheritance: - {% block methods %} - {% if methods %} - .. rubric:: Methods - + {% block attributes %} + {% if attributes %} + .. rubric:: Attributes .. autosummary:: :toctree: . 
@@ -22,6 +21,11 @@ {% endif %} {% endblock %} + + {% block methods %} + {% if methods %} + .. rubric:: Methods + .. autosummary:: :toctree: . {% for item in methods %}