diff --git a/docs/release_notes/v0.15.1.rst b/docs/release_notes/v0.15.1.rst
index 177d5ac3b5..402b3dfb97 100644
--- a/docs/release_notes/v0.15.1.rst
+++ b/docs/release_notes/v0.15.1.rst
@@ -6,6 +6,7 @@ Changes
 ~~~~~~~
 - Remove ``labels_key`` from :class:`~scvi.model.MULTIVI` as it is not used in the model (`#1393`_).
 - Use scvi-tools mean/inv_disp parameterization of negative binomial for :class:`~scvi.model.JaxSCVI` likelihood (`#1386`_).
+- Use multiple particles optionally in :class:`~scvi.model.JaxSCVI` (`#1385`_).
 
 Contributors
 ~~~~~~~~~~~~
@@ -18,5 +19,6 @@ Contributors
 .. _`@watiss`: https://github.com/watiss
 
 .. _`#1393`: https://github.com/YosefLab/scvi-tools/pull/1393
+.. _`#1385`: https://github.com/YosefLab/scvi-tools/pull/1385
 .. _`#1386`: https://github.com/YosefLab/scvi-tools/pull/1386
 
diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py
index df07fc5806..21b2d47958 100644
--- a/scvi/distributions/_negative_binomial.py
+++ b/scvi/distributions/_negative_binomial.py
@@ -570,6 +570,10 @@ def __init__(
     def mean(self):
         return self._mean
 
+    @property
+    def inverse_dispersion(self):
+        return self._inverse_dispersion
+
     @validate_sample
     def log_prob(self, value):
         # theta is inverse_dispersion
diff --git a/scvi/model/_jaxscvi.py b/scvi/model/_jaxscvi.py
index 6e23ff7a45..c675205d11 100644
--- a/scvi/model/_jaxscvi.py
+++ b/scvi/model/_jaxscvi.py
@@ -270,6 +270,8 @@ def get_latent_representation(
         self,
         adata: Optional[AnnData] = None,
         indices: Optional[Sequence[int]] = None,
+        give_mean: bool = True,
+        mc_samples: int = 1,
         batch_size: Optional[int] = None,
     ) -> np.ndarray:
         r"""
@@ -301,14 +303,19 @@
         @jax.jit
         def _get_val(array_dict):
-            out = self.bound_module(array_dict)
-            return out.qz.mean
+            out = self.bound_module(array_dict, n_samples=mc_samples)
+            return out
 
         latent = []
         for array_dict in scdl:
-            mean = _get_val(array_dict)
-            latent.append(mean)
-        latent = jnp.concatenate(latent)
+            out = _get_val(array_dict)
+            if give_mean:
+                z = out.qz.mean
+            else:
+                z = out.z
+            latent.append(z)
+        concat_axis = 0 if ((mc_samples == 1) or give_mean) else 1
+        latent = jnp.concatenate(latent, axis=concat_axis)
 
         return np.array(jax.device_get(latent))
diff --git a/scvi/module/_jaxvae.py b/scvi/module/_jaxvae.py
index e7753708b4..3fa5fb613f 100644
--- a/scvi/module/_jaxvae.py
+++ b/scvi/module/_jaxvae.py
@@ -75,7 +75,8 @@ def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, is_training: bool):
         h = nn.relu(h)
         h = nn.Dropout(self.dropout_rate)(h, deterministic=not is_training)
         # skip connection
-        h = Dense(self.n_hidden)(jnp.concatenate([h, batch], axis=-1))
+        h = Dense(self.n_hidden)(h)
+        h += Dense(self.n_hidden)(batch)
         h = nn.BatchNorm(momentum=0.99, epsilon=0.001)(
             h, use_running_average=not is_training
         )
@@ -91,6 +92,7 @@ class VAEOutput(NamedTuple):
     kl: jnp.ndarray
     px: NegativeBinomial
     qz: dist.Normal
+    z: jnp.ndarray
 
 
 class JaxVAE(nn.Module):
@@ -105,7 +107,9 @@ class JaxVAE(nn.Module):
     eps: float = 1e-8
 
     @nn.compact
-    def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
+    def __call__(
+        self, array_dict: Dict[str, np.ndarray], n_samples: int = 1
+    ) -> VAEOutput:
 
         x = array_dict[REGISTRY_KEYS.X_KEY]
         batch = array_dict[REGISTRY_KEYS.BATCH_KEY]
@@ -122,7 +126,8 @@ def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
         qz = dist.Normal(mean, stddev)
         z_rng = self.make_rng("z")
-        z = qz.rsample(z_rng)
+        sample_shape = () if n_samples == 1 else (n_samples,)
+        z = qz.rsample(z_rng, sample_shape=sample_shape)
         rho_unnorm, disp = FlaxDecoder(
             n_input=self.n_input,
             dropout_rate=0.0,
@@ -142,4 +147,4 @@ def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
         rec_loss = -px.log_prob(x).sum(-1)
         kl_div = dist.kl_divergence(qz, dist.Normal(0, 1)).sum(-1)
 
-        return VAEOutput(rec_loss, kl_div, px, qz)
+        return VAEOutput(rec_loss, kl_div, px, qz, z)
diff --git a/tests/models/test_models.py b/tests/models/test_models.py
index 2faf8e7fb7..ffac47ea5a 100644
--- a/tests/models/test_models.py
+++ b/tests/models/test_models.py
@@ -99,7 +99,10 @@ def test_jax_scvi():
     model = JaxSCVI(adata, n_latent=n_latent, gene_likelihood="poisson")
     model.train(1, train_size=0.5)
-    model.get_latent_representation()
+    z1 = model.get_latent_representation(give_mean=True, mc_samples=1)
+    assert z1.ndim == 2
+    z2 = model.get_latent_representation(give_mean=False, mc_samples=15)
+    assert (z2.ndim == 3) and (z2.shape[0] == 15)
 
 
 def test_scvi(save_path):
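
Usage sketch for the new keywords, mirroring the updated test above. The dataset setup (`synthetic_iid`, the `setup_anndata` call) is an assumption for illustration; only the `give_mean` / `mc_samples` keywords come from this patch.

    # Hypothetical usage; dataset setup is assumed, only the
    # get_latent_representation keywords are introduced by this patch.
    from scvi.data import synthetic_iid
    from scvi.model import JaxSCVI

    adata = synthetic_iid()
    JaxSCVI.setup_anndata(adata, batch_key="batch")
    model = JaxSCVI(adata, n_latent=10)
    model.train(1, train_size=0.5)

    # Default behavior: posterior mean of q(z | x), shape (n_cells, n_latent).
    z_mean = model.get_latent_representation()

    # Monte Carlo samples instead: shape (mc_samples, n_cells, n_latent).
    z_samples = model.get_latent_representation(give_mean=False, mc_samples=15)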
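
Why `concat_axis` flips to 1: with `give_mean=False` and `mc_samples > 1`, each minibatch result carries a leading sample dimension, so cells sit on axis 1 and minibatches must be concatenated there. A minimal shape check (array sizes are illustrative):

    import jax.numpy as jnp

    # Two minibatches of 3 and 2 cells, 15 MC samples, 10 latent dims.
    b1 = jnp.zeros((15, 3, 10))
    b2 = jnp.zeros((15, 2, 10))

    # Concatenating over cells means axis=1 once the sample axis leads.
    z = jnp.concatenate([b1, b2], axis=1)
    assert z.shape == (15, 5, 10)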
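
On the decoder skip-connection change: `Dense(n)(concat([h, batch]))` and `Dense(n)(h) + Dense(n)(batch)` are equivalent parameterizations, because a dense layer over a concatenation splits into a sum of dense layers over the parts (the two biases collapse into one effective bias). A NumPy check of the identity (shapes are made up):

    import numpy as np

    rng = np.random.default_rng(0)
    h = rng.normal(size=(4, 8))    # hidden activations
    b = rng.normal(size=(4, 2))    # batch covariate
    W = rng.normal(size=(10, 10))  # weight acting on concat([h, b], axis=-1)
    c = rng.normal(size=10)        # bias

    full = np.concatenate([h, b], axis=-1) @ W.T + c
    split = h @ W[:, :8].T + b @ W[:, 8:].T + c
    assert np.allclose(full, split)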