diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py
index ab7be08ae1..26530ecb14 100644
--- a/scvi/train/_trainingplans.py
+++ b/scvi/train/_trainingplans.py
@@ -4,6 +4,7 @@
 import pyro
 import pytorch_lightning as pl
 import torch
+from pyro.nn import PyroModule
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
 from scvi import _CONSTANTS
@@ -656,9 +657,13 @@ def __init__(
         self.pyro_guide = self.module.guide
         self.pyro_model = self.module.model
 
-        self.use_kl_weight = (
-            "kl_weight" in signature(self.pyro_model.forward).parameters
-        )
+        self.use_kl_weight = False
+        if isinstance(self.pyro_model, PyroModule):
+            self.use_kl_weight = (
+                "kl_weight" in signature(self.pyro_model.forward).parameters
+            )
+        elif callable(self.pyro_model):
+            self.use_kl_weight = "kl_weight" in signature(self.pyro_model).parameters
 
         self.svi = pyro.infer.SVI(
             model=self.pyro_model,
diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py
index 21a7c54739..9a5606b0be 100644
--- a/tests/models/test_pyro.py
+++ b/tests/models/test_pyro.py
@@ -22,6 +22,7 @@
     PyroSviTrainMixin,
 )
 from scvi.module.base import PyroBaseModuleClass
+from scvi.nn import DecoderSCVI, Encoder
 from scvi.train import PyroTrainingPlan, Trainer
 
 
@@ -386,6 +387,108 @@ def test_pyro_bayesian_train_sample_mixin_with_local_full_data():
     )
 
 
+class FunctionBasedPyroModule(PyroBaseModuleClass):
+    def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int):
+
+        super().__init__()
+        self.n_input = n_input
+        self.n_latent = n_latent
+        self.epsilon = 5.0e-3
+        # z encoder goes from the n_input-dimensional data to an n_latent-d
+        # latent space representation
+        self.encoder = Encoder(
+            n_input,
+            n_latent,
+            n_layers=n_layers,
+            n_hidden=n_hidden,
+            dropout_rate=0.1,
+        )
+        # decoder goes from n_latent-dimensional space to n_input-d data
+        self.decoder = DecoderSCVI(
+            n_latent,
+            n_input,
+            n_layers=n_layers,
+            n_hidden=n_hidden,
+        )
+        # This gene-level parameter modulates the variance of the observation distribution
+        self.px_r = torch.nn.Parameter(torch.ones(self.n_input))
+
+    @staticmethod
+    def _get_fn_args_from_batch(tensor_dict):
+        x = tensor_dict[_CONSTANTS.X_KEY]
+        log_library = torch.log(torch.sum(x, dim=1, keepdim=True) + 1e-6)
+        return (x, log_library), {}
+
+    def model(self, x, log_library):
+        # register the PyTorch modules (encoder, decoder, px_r) with Pyro
+        pyro.module("scvi", self)
+        with pyro.plate("data", x.shape[0]):
+            # setup hyperparameters for prior p(z)
+            z_loc = x.new_zeros(torch.Size((x.shape[0], self.n_latent)))
+            z_scale = x.new_ones(torch.Size((x.shape[0], self.n_latent)))
+            # sample from prior (value will be sampled by guide when computing the ELBO)
+            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
+            # decode the latent code z
+            px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library)
+            # build count distribution
+            nb_logits = (px_rate + self.epsilon).log() - (
+                self.px_r.exp() + self.epsilon
+            ).log()
+            x_dist = dist.ZeroInflatedNegativeBinomial(
+                gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits
+            )
+            # score against actual counts
+            pyro.sample("obs", x_dist.to_event(1), obs=x)
+
+    def guide(self, x, log_library):
+        # define the guide (i.e. variational distribution) q(z|x)
+        pyro.module("scvi", self)
+        with pyro.plate("data", x.shape[0]):
+            # use the encoder to get the parameters used to define q(z|x)
+            x_ = torch.log(1 + x)
+            z_loc, z_scale, _ = self.encoder(x_)
+            # sample the latent code z
+            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
+
+
+class FunctionBasedPyroModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
+    def __init__(
+        self,
+        adata: AnnData,
+    ):
+        # clear Pyro's param store in case a previously created model shares parameter names.
+        clear_param_store()
+
+        super().__init__(adata)
+
+        self.module = FunctionBasedPyroModule(
+            n_input=adata.n_vars,
+            n_hidden=32,
+            n_latent=5,
+            n_layers=1,
+        )
+        self._model_summary_string = "FunctionBasedPyroModel"
+        self.init_params_ = self._get_init_params(locals())
+
+    @staticmethod
+    def setup_anndata(
+        adata: AnnData,
+    ) -> Optional[AnnData]:
+        pass
+
+
+def test_function_based_pyro_module():
+    use_gpu = torch.cuda.is_available()
+    adata = synthetic_iid()
+    mod = FunctionBasedPyroModel(adata)
+    mod.train(
+        max_epochs=1,
+        batch_size=256,
+        lr=0.01,
+        use_gpu=use_gpu,
+    )
+
+
 def test_lda_model():
     use_gpu = torch.cuda.is_available()
     n_topics = 5
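
For context, a minimal sketch of why the training-plan change branches on PyroModule: calling inspect.signature() on an nn.Module/PyroModule instance resolves to nn.Module.__call__(*input, **kwargs), which hides a user-defined kl_weight parameter, so the module's forward must be inspected instead; a plain function or bound method (as in the function-based module above) can be inspected directly. The names ModuleBasedModel, function_model, and detect_kl_weight below are hypothetical illustrations, not part of the diff.

    from inspect import signature

    from pyro.nn import PyroModule


    class ModuleBasedModel(PyroModule):
        # forward carries kl_weight; signature(instance) would only see
        # nn.Module.__call__ and miss it, hence the .forward inspection
        def forward(self, x, kl_weight=1.0):
            ...


    def function_model(x, log_library):
        ...


    def detect_kl_weight(pyro_model):
        # mirrors the branching added to PyroTrainingPlan.__init__ above
        if isinstance(pyro_model, PyroModule):
            return "kl_weight" in signature(pyro_model.forward).parameters
        elif callable(pyro_model):
            return "kl_weight" in signature(pyro_model).parameters
        return False


    assert detect_kl_weight(ModuleBasedModel()) is True
    assert detect_kl_weight(function_model) is False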