Skip to content

Commit

Permalink
Backport PR scverse#1242: Check if function based pyro module for `us…
Browse files Browse the repository at this point in the history
…e_kl_weight`
  • Loading branch information
justjhong authored and meeseeksmachine committed Oct 28, 2021
1 parent 6584927 commit 63460f1
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
11 changes: 8 additions & 3 deletions scvi/train/_trainingplans.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pyro
import pytorch_lightning as pl
import torch
from pyro.nn import PyroModule
from torch.optim.lr_scheduler import ReduceLROnPlateau

from scvi import _CONSTANTS
Expand Down Expand Up @@ -656,9 +657,13 @@ def __init__(
self.pyro_guide = self.module.guide
self.pyro_model = self.module.model

self.use_kl_weight = (
"kl_weight" in signature(self.pyro_model.forward).parameters
)
self.use_kl_weight = False
if isinstance(self.pyro_model, PyroModule):
self.use_kl_weight = (
"kl_weight" in signature(self.pyro_model.forward).parameters
)
elif callable(self.pyro_model):
self.use_kl_weight = "kl_weight" in signature(self.pyro_model).parameters

self.svi = pyro.infer.SVI(
model=self.pyro_model,
Expand Down
103 changes: 103 additions & 0 deletions tests/models/test_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PyroSviTrainMixin,
)
from scvi.module.base import PyroBaseModuleClass
from scvi.nn import DecoderSCVI, Encoder
from scvi.train import PyroTrainingPlan, Trainer


Expand Down Expand Up @@ -386,6 +387,108 @@ def test_pyro_bayesian_train_sample_mixin_with_local_full_data():
)


class FunctionBasedPyroModule(PyroBaseModuleClass):
def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int):

super().__init__()
self.n_input = n_input
self.n_latent = n_latent
self.epsilon = 5.0e-3
# z encoder goes from the n_input-dimensional data to an n_latent-d
# latent space representation
self.encoder = Encoder(
n_input,
n_latent,
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=0.1,
)
# decoder goes from n_latent-dimensional space to n_input-d data
self.decoder = DecoderSCVI(
n_latent,
n_input,
n_layers=n_layers,
n_hidden=n_hidden,
)
# This gene-level parameter modulates the variance of the observation distribution
self.px_r = torch.nn.Parameter(torch.ones(self.n_input))

@staticmethod
def _get_fn_args_from_batch(tensor_dict):
x = tensor_dict[_CONSTANTS.X_KEY]
log_library = torch.log(torch.sum(x, dim=1, keepdim=True) + 1e-6)
return (x, log_library), {}

def model(self, x, log_library):
# register PyTorch module `decoder` with Pyro
pyro.module("scvi", self)
with pyro.plate("data", x.shape[0]):
# setup hyperparameters for prior p(z)
z_loc = x.new_zeros(torch.Size((x.shape[0], self.n_latent)))
z_scale = x.new_ones(torch.Size((x.shape[0], self.n_latent)))
# sample from prior (value will be sampled by guide when computing the ELBO)
z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
# decode the latent code z
px_scale, _, px_rate, px_dropout = self.decoder("gene", z, log_library)
# build count distribution
nb_logits = (px_rate + self.epsilon).log() - (
self.px_r.exp() + self.epsilon
).log()
x_dist = dist.ZeroInflatedNegativeBinomial(
gate_logits=px_dropout, total_count=self.px_r.exp(), logits=nb_logits
)
# score against actual counts
pyro.sample("obs", x_dist.to_event(1), obs=x)

def guide(self, x, log_library):
# define the guide (i.e. variational distribution) q(z|x)
pyro.module("scvi", self)
with pyro.plate("data", x.shape[0]):
# use the encoder to get the parameters used to define q(z|x)
x_ = torch.log(1 + x)
z_loc, z_scale, _ = self.encoder(x_)
# sample the latent code z
pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))


class FunctionBasedPyroModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass):
def __init__(
self,
adata: AnnData,
):
# in case any other model was created before that shares the same parameter names.
clear_param_store()

super().__init__(adata)

self.module = FunctionBasedPyroModule(
n_input=adata.n_vars,
n_hidden=32,
n_latent=5,
n_layers=1,
)
self._model_summary_string = "FunctionBasedPyroModel"
self.init_params_ = self._get_init_params(locals())

@staticmethod
def setup_anndata(
adata: AnnData,
) -> Optional[AnnData]:
pass


def test_function_based_pyro_module():
use_gpu = torch.cuda.is_available()
adata = synthetic_iid()
mod = FunctionBasedPyroModel(adata)
mod.train(
max_epochs=1,
batch_size=256,
lr=0.01,
use_gpu=use_gpu,
)


def test_lda_model():
use_gpu = torch.cuda.is_available()
n_topics = 5
Expand Down

0 comments on commit 63460f1

Please sign in to comment.