Backport PR #1385: JAXSCVI multi particles (#1404)
Co-authored-by: Pierre Boyeau <[email protected]>
meeseeksmachine and PierreBoyeau authored Mar 7, 2022
1 parent e2ea1a6 commit b162cec
Showing 5 changed files with 31 additions and 10 deletions.
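In short: this backport lets `JaxSCVI` optionally draw several Monte Carlo particles (samples) from the variational posterior when computing latent representations, instead of always returning the single posterior mean. A hedged usage sketch, assuming `adata` has already been registered via `JaxSCVI.setup_anndata` (variable names are illustrative):

```python
from scvi.model import JaxSCVI

model = JaxSCVI(adata, n_latent=10)
model.train(1)

# Default behavior, unchanged: posterior mean, shape (n_cells, n_latent).
z_mean = model.get_latent_representation()

# New: 15 posterior particles per cell, shape (15, n_cells, n_latent).
z_particles = model.get_latent_representation(give_mean=False, mc_samples=15)
```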
2 changes: 2 additions & 0 deletions docs/release_notes/v0.15.1.rst
```diff
@@ -6,6 +6,7 @@ Changes
 ~~~~~~~
 - Remove ``labels_key`` from :class:`~scvi.model.MULTIVI` as it is not used in the model (`#1393`_).
 - Use scvi-tools mean/inv_disp parameterization of negative binomial for :class:`~scvi.model.JaxSCVI` likelihood (`#1386`_).
+- Use multiple particles optionally in :class:`~scvi.model.JaxSCVI` (`#1385`_).
 
 Contributors
 ~~~~~~~~~~~~
@@ -18,5 +19,6 @@ Contributors
 .. _`@watiss`: https://github.com/watiss
 
 .. _`#1393`: https://github.com/YosefLab/scvi-tools/pull/1393
+.. _`#1385`: https://github.com/YosefLab/scvi-tools/pull/1385
 .. _`#1386`: https://github.com/YosefLab/scvi-tools/pull/1386
 
```
4 changes: 4 additions & 0 deletions scvi/distributions/_negative_binomial.py
```diff
@@ -570,6 +570,10 @@ def __init__(
     def mean(self):
         return self._mean
 
+    @property
+    def inverse_dispersion(self):
+        return self._inverse_dispersion
+
     @validate_sample
     def log_prob(self, value):
         # theta is inverse_dispersion
```
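The new `inverse_dispersion` property exposes the shape parameter theta next to the existing `mean` property. A minimal sketch of what this enables, assuming the mean/inverse-dispersion parameterization mentioned in the release notes above, under which Var[x] = mu + mu^2 / theta:

```python
# `px` is a NegativeBinomial as returned in the module output (see below);
# the names here are illustrative.
mu = px.mean
theta = px.inverse_dispersion

# Larger inverse dispersion -> variance closer to the Poisson limit (mu).
variance = mu + mu**2 / theta
```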
17 changes: 12 additions & 5 deletions scvi/model/_jaxscvi.py
```diff
@@ -270,6 +270,8 @@ def get_latent_representation(
         self,
         adata: Optional[AnnData] = None,
         indices: Optional[Sequence[int]] = None,
+        give_mean: bool = True,
+        mc_samples: int = 1,
         batch_size: Optional[int] = None,
     ) -> np.ndarray:
         r"""
@@ -301,14 +303,19 @@ def get_latent_representation(
 
         @jax.jit
         def _get_val(array_dict):
-            out = self.bound_module(array_dict)
-            return out.qz.mean
+            out = self.bound_module(array_dict, n_samples=mc_samples)
+            return out
 
         latent = []
         for array_dict in scdl:
-            mean = _get_val(array_dict)
-            latent.append(mean)
-        latent = jnp.concatenate(latent)
+            out = _get_val(array_dict)
+            if give_mean:
+                z = out.qz.mean
+            else:
+                z = out.z
+            latent.append(z)
+        concat_axis = 0 if ((mc_samples == 1) or give_mean) else 1
+        latent = jnp.concatenate(latent, axis=concat_axis)
 
         return np.array(jax.device_get(latent))
```
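Why the concatenation axis switches: when `mc_samples > 1` and `give_mean=False`, each minibatch yields an array of shape `(mc_samples, batch_cells, n_latent)`, so cells live on axis 1 rather than axis 0. A minimal shape sketch (sizes are illustrative):

```python
import jax.numpy as jnp

mc_samples, n_latent = 5, 10
batches = [jnp.zeros((mc_samples, n, n_latent)) for n in (128, 128, 44)]

# Concatenating on axis 0 would stack particle blocks, not cells;
# axis 1 correctly yields (mc_samples, total_cells, n_latent).
latent = jnp.concatenate(batches, axis=1)
assert latent.shape == (5, 300, 10)
```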
13 changes: 9 additions & 4 deletions scvi/module/_jaxvae.py
```diff
@@ -75,7 +75,8 @@ def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, is_training: bool):
         h = nn.relu(h)
         h = nn.Dropout(self.dropout_rate)(h, deterministic=not is_training)
         # skip connection
-        h = Dense(self.n_hidden)(jnp.concatenate([h, batch], axis=-1))
+        h = Dense(self.n_hidden)(h)
+        h += Dense(self.n_hidden)(batch)
         h = nn.BatchNorm(momentum=0.99, epsilon=0.001)(
             h, use_running_average=not is_training
         )
```
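The refactored skip connection is the same affine map: a dense layer on `concat([h, batch])` computes `h @ W_h + batch @ W_b + bias`, which two separate dense layers reproduce up to a redundant extra bias. A small sanity-check sketch (shapes are illustrative):

```python
import jax.numpy as jnp

h = jnp.ones((8, 128))       # hidden activations
batch = jnp.ones((8, 2))     # one-hot batch covariate
W = jnp.ones((130, 64))      # weights of a Dense over the concatenation
W_h, W_b = W[:128], W[128:]  # the same weights, split per input

lhs = jnp.concatenate([h, batch], axis=-1) @ W
rhs = h @ W_h + batch @ W_b  # biases omitted; they fold together
assert jnp.allclose(lhs, rhs)
```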
```diff
@@ -91,6 +92,7 @@ class VAEOutput(NamedTuple):
     kl: jnp.ndarray
     px: NegativeBinomial
     qz: dist.Normal
+    z: jnp.ndarray
 
 
 class JaxVAE(nn.Module):
@@ -105,7 +107,9 @@ class JaxVAE(nn.Module):
     eps: float = 1e-8
 
     @nn.compact
-    def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
+    def __call__(
+        self, array_dict: Dict[str, np.ndarray], n_samples: int = 1
+    ) -> VAEOutput:
 
         x = array_dict[REGISTRY_KEYS.X_KEY]
         batch = array_dict[REGISTRY_KEYS.BATCH_KEY]
```
```diff
@@ -122,7 +126,8 @@ def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
 
         qz = dist.Normal(mean, stddev)
         z_rng = self.make_rng("z")
-        z = qz.rsample(z_rng)
+        sample_shape = () if n_samples == 1 else (n_samples,)
+        z = qz.rsample(z_rng, sample_shape=sample_shape)
         rho_unnorm, disp = FlaxDecoder(
             n_input=self.n_input,
             dropout_rate=0.0,
```
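`sample_shape` is what stacks the particles: numpyro distributions prepend `sample_shape` to the batch shape, so `n_samples == 1` keeps the old 2-D output while `n_samples > 1` adds a leading particle axis. A hedged shape sketch:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

qz = dist.Normal(jnp.zeros((8, 10)), jnp.ones((8, 10)))
rng = jax.random.PRNGKey(0)

assert qz.rsample(rng).shape == (8, 10)                        # one particle
assert qz.rsample(rng, sample_shape=(5,)).shape == (5, 8, 10)  # five particles
```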
```diff
@@ -142,4 +147,4 @@ def __call__(self, array_dict: Dict[str, np.ndarray]) -> VAEOutput:
         rec_loss = -px.log_prob(x).sum(-1)
         kl_div = dist.kl_divergence(qz, dist.Normal(0, 1)).sum(-1)
 
-        return VAEOutput(rec_loss, kl_div, px, qz)
+        return VAEOutput(rec_loss, kl_div, px, qz, z)
```
5 changes: 4 additions & 1 deletion tests/models/test_models.py
```diff
@@ -99,7 +99,10 @@ def test_jax_scvi():
 
     model = JaxSCVI(adata, n_latent=n_latent, gene_likelihood="poisson")
     model.train(1, train_size=0.5)
-    model.get_latent_representation()
+    z1 = model.get_latent_representation(give_mean=True, mc_samples=1)
+    assert z1.ndim == 2
+    z2 = model.get_latent_representation(give_mean=False, mc_samples=15)
+    assert (z2.ndim == 3) and (z2.shape[0] == 15)
 
 
 def test_scvi(save_path):
```
