Skip to content

Commit

Permalink
Deprecate LossRecorder, create LossOutput dataclass alternative (#1749)
Browse files Browse the repository at this point in the history

* better optimizer api

* better optimizer api

* typing

* refactor lossrecorder

* finalize

* more efficient metrics

* Revert "more efficient metrics"

This reverts commit 5905efb.

* no floats

* faster metrics

* generic sum

* compute more things on init for gpu

* separate types

* fix duplicate code

* cleanup

* works but slow

* improvements

* restore old functionality

* make lossrecorder somewhat compatible, more ports

* convert more

* fix gimvi

* fix gimvi

* better deprecation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* address comments

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
adamgayoso and pre-commit-ci[bot] authored Oct 25, 2022
1 parent b2b6224 commit 0c3808e
Show file tree
Hide file tree
Showing 20 changed files with 318 additions and 185 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ version = "0.19.0a0"
[tool.poetry.dependencies]
anndata = ">=0.7.5"
black = {version = ">=22.3", optional = true}
chex = "*"
codecov = {version = ">=2.0.8", optional = true}
docrep = ">=0.3.2"
flake8 = {version = ">=3.7.7", optional = true}
Expand Down
7 changes: 4 additions & 3 deletions scvi/external/gimvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from scvi import REGISTRY_KEYS
from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import Encoder, MultiDecoder, MultiEncoder, one_hot

torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -506,8 +506,9 @@ def loss(
kl_divergence_l = torch.zeros_like(kl_divergence_z)

kl_local = kl_divergence_l + kl_divergence_z
kl_global = torch.tensor(0.0)

loss = torch.mean(reconstruction_loss + kl_weight * kl_local) * x.size(0)

return LossRecorder(loss, reconstruction_loss, kl_local, kl_global)
return LossOutput(
loss=loss, reconstruction_loss=reconstruction_loss, kl_local=kl_local
)
24 changes: 12 additions & 12 deletions scvi/external/gimvi/_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,27 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
)
if optimizer_idx == 0:
# batch contains both data loader outputs
scvi_loss_objs = []
loss_output_objs = []
n_obs = 0
zs = []
for (i, tensors) in enumerate(batch):
n_obs += tensors[REGISTRY_KEYS.X_KEY].shape[0]
self.loss_kwargs.update(dict(kl_weight=self.kl_weight, mode=i))
inference_kwargs = dict(mode=i)
generative_kwargs = dict(mode=i)
inference_outputs, _, scvi_loss = self.forward(
inference_outputs, _, loss_output = self.forward(
tensors,
loss_kwargs=self.loss_kwargs,
inference_kwargs=inference_kwargs,
generative_kwargs=generative_kwargs,
)
zs.append(inference_outputs["z"])
scvi_loss_objs.append(scvi_loss)
loss_output_objs.append(loss_output)

loss = sum([scl.loss for scl in scvi_loss_objs])
loss = sum([scl.loss for scl in loss_output_objs])
loss /= n_obs
rec_loss = sum([scl.reconstruction_loss.sum() for scl in scvi_loss_objs])
kl = sum([scl.kl_local.sum() for scl in scvi_loss_objs])
rec_loss = sum([scl.reconstruction_loss_sum for scl in loss_output_objs])
kl = sum([scl.kl_local_sum for scl in loss_output_objs])

# fool classifier if doing adversarial training
batch_tensor = [
Expand Down Expand Up @@ -98,18 +98,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx):
self.loss_kwargs.update(dict(kl_weight=self.kl_weight, mode=dataloader_idx))
inference_kwargs = dict(mode=dataloader_idx)
generative_kwargs = dict(mode=dataloader_idx)
_, _, scvi_loss = self.forward(
_, _, loss_output = self.forward(
batch,
loss_kwargs=self.loss_kwargs,
inference_kwargs=inference_kwargs,
generative_kwargs=generative_kwargs,
)
reconstruction_loss = scvi_loss.reconstruction_loss
reconstruction_loss = loss_output.reconstruction_loss_sum
return {
"reconstruction_loss_sum": reconstruction_loss.sum(),
"kl_local_sum": scvi_loss.kl_local.sum(),
"kl_global": scvi_loss.kl_global,
"n_obs": reconstruction_loss.shape[0],
"reconstruction_loss_sum": reconstruction_loss,
"kl_local_sum": loss_output.kl_local_sum,
"kl_global": loss_output.kl_global,
"n_obs": loss_output.n_obs_minibatch,
}

def validation_epoch_end(self, outputs):
Expand Down
7 changes: 4 additions & 3 deletions scvi/external/scar/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scvi._compat import Literal
from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial
from scvi.module._vae import VAE
from scvi.module.base import LossRecorder, auto_move_data
from scvi.module.base import LossOutput, auto_move_data
from scvi.nn import FCLayers

torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -384,5 +384,6 @@ def loss(
kl_local = dict(
kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
)
kl_global = torch.tensor(0.0)
return LossRecorder(loss, reconst_loss, kl_local, kl_global)
return LossOutput(
loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local
)
10 changes: 6 additions & 4 deletions scvi/external/stereoscope/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data


class RNADeconv(BaseModuleClass):
Expand Down Expand Up @@ -110,7 +110,7 @@ def loss(
reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)
loss = torch.sum(scaling_factor * reconst_loss)

return LossRecorder(loss, reconst_loss, torch.zeros((1,)), torch.tensor(0.0))
return LossOutput(loss=loss, reconstruction_loss=reconst_loss)

@torch.inference_mode()
def sample(
Expand Down Expand Up @@ -239,8 +239,10 @@ def loss(
else:
# the original way it is done in Stereoscope; we use this option to show reproducibility of their codebase
loss = torch.sum(reconst_loss) + neg_log_likelihood_prior
return LossRecorder(
loss, reconst_loss, torch.zeros((1,)), neg_log_likelihood_prior
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_global=neg_log_likelihood_prior,
)

@torch.inference_mode()
Expand Down
11 changes: 9 additions & 2 deletions scvi/external/tangram/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import jax.numpy as jnp

from scvi.module.base import JaxBaseModuleClass, LossRecorder, flax_configure
from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure


class _TANGRAM_REGISTRY_KEYS_NT(NamedTuple):
Expand Down Expand Up @@ -151,4 +151,11 @@ def loss(
total_loss = -expression_term - regularizer_term + count_term + f_reg
total_loss = total_loss + density_term

return LossRecorder(total_loss, expression_term, regularizer_term)
return LossOutput(
loss=total_loss,
n_obs_minibatch=sp.shape[0],
extra_metrics={
"expression_term": expression_term,
"regularizer_term": regularizer_term,
},
)
14 changes: 9 additions & 5 deletions scvi/model/base/_log_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def compute_elbo(vae, data_loader, feed_labels=True, **kwargs):
for tensors in data_loader:
_, _, scvi_loss = vae(tensors, **kwargs)

recon_loss = scvi_loss.reconstruction_loss
kl_local = scvi_loss.kl_local
elbo += torch.sum(recon_loss + kl_local).item()
recon_loss = scvi_loss.reconstruction_loss_sum
kl_local = scvi_loss.kl_local_sum
elbo += (recon_loss + kl_local).item()

kl_global = scvi_loss.kl_global
kl_global = scvi_loss.kl_global_sum
n_samples = len(data_loader.indices)
elbo += kl_global
return elbo / n_samples
Expand All @@ -41,7 +41,11 @@ def compute_reconstruction_error(vae, data_loader, **kwargs):
for tensors in data_loader:
loss_kwargs = dict(kl_weight=1)
_, _, losses = vae(tensors, loss_kwargs=loss_kwargs)
for key, value in losses._reconstruction_loss.items():
if not isinstance(losses.reconstruction_loss, dict):
rec_loss_dict = {"reconstruction_loss": losses.reconstruction_loss}
else:
rec_loss_dict = losses.reconstruction_loss
for key, value in rec_loss_dict.items():
if key in log_lkl:
log_lkl[key] += torch.sum(value).item()
else:
Expand Down
4 changes: 2 additions & 2 deletions scvi/model/base/_vaemixin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -98,7 +98,7 @@ def get_reconstruction_error(
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
) -> Union[float, Dict[str, float]]:
) -> float:
r"""
Return the reconstruction error for the data.
Expand Down
6 changes: 4 additions & 2 deletions scvi/module/_jaxvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from scvi import REGISTRY_KEYS
from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial
from scvi.module.base import JaxBaseModuleClass, LossRecorder, flax_configure
from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure


class Dense(nn.Dense):
Expand Down Expand Up @@ -222,4 +222,6 @@ def loss(
loss = jnp.mean(reconst_loss + weighted_kl_local)

kl_local = kl_divergence_z
return LossRecorder(loss, reconst_loss, kl_local)
return LossOutput(
loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local
)
9 changes: 6 additions & 3 deletions scvi/module/_mrdeconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.distributions import NegativeBinomial
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import FCLayers


Expand Down Expand Up @@ -309,8 +309,11 @@ def loss(
+ glo_neg_log_likelihood_prior
)

return LossRecorder(
loss, reconst_loss, neg_log_likelihood_prior, glo_neg_log_likelihood_prior
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_local=neg_log_likelihood_prior,
kl_global=glo_neg_log_likelihood_prior,
)

@torch.inference_mode()
Expand Down
5 changes: 2 additions & 3 deletions scvi/module/_multivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ZeroInflatedNegativeBinomial,
)
from scvi.module._peakvae import Decoder as DecoderPeakVI
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import DecoderSCVI, Encoder, FCLayers, one_hot

from ._utils import masked_softmax
Expand Down Expand Up @@ -873,8 +873,7 @@ def loss(
loss = torch.mean(recon_loss + weighted_kl_local + kld_paired)

kl_local = dict(kl_divergence_z=kl_div_z)
kl_global = torch.tensor(0.0)
return LossRecorder(loss, recon_loss, kl_local, kl_global)
return LossOutput(loss=loss, reconstruction_loss=recon_loss, kl_local=kl_local)

def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout):
"""Computes the reconstruction loss for the expression data."""
Expand Down
4 changes: 2 additions & 2 deletions scvi/module/_peakvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import Encoder, FCLayers


Expand Down Expand Up @@ -335,4 +335,4 @@ def loss(

loss = (rl.sum() + kld * kl_weight).sum()

return LossRecorder(loss, rl, kld, kl_global=torch.tensor(0.0))
return LossOutput(loss=loss, reconstruction_loss=rl, kl_local=kld)
41 changes: 23 additions & 18 deletions scvi/module/_scanvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.module.base import LossRecorder, auto_move_data
from scvi.module.base import LossOutput, auto_move_data
from scvi.nn import Decoder, Encoder

from ._classifier import Classifier
Expand Down Expand Up @@ -300,18 +300,21 @@ def loss(
if labelled_tensors is not None:
classifier_loss = self.classification_loss(labelled_tensors)
loss += classifier_loss * classification_ratio
return LossRecorder(
loss,
reconst_loss,
kl_locals,
classification_loss=classifier_loss,
n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0],
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_local=kl_locals,
extra_metrics={
"classification_loss": classifier_loss,
"n_labelled_tensors": labelled_tensors[
REGISTRY_KEYS.X_KEY
].shape[0],
},
)
return LossRecorder(
loss,
reconst_loss,
kl_locals,
kl_global=torch.tensor(0.0),
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_local=kl_locals,
)

probs = self.classifier(z1)
Expand All @@ -333,10 +336,12 @@ def loss(
if labelled_tensors is not None:
classifier_loss = self.classification_loss(labelled_tensors)
loss += classifier_loss * classification_ratio
return LossRecorder(
loss,
reconst_loss,
kl_divergence,
classification_loss=classifier_loss,
return LossOutput(
loss=loss,
reconstruction_loss=reconst_loss,
kl_local=kl_divergence,
extra_metrics={"classification_loss": classifier_loss},
)
return LossRecorder(loss, reconst_loss, kl_divergence)
return LossOutput(
loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence
)
8 changes: 5 additions & 3 deletions scvi/module/_totalvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
NegativeBinomialMixture,
ZeroInflatedNegativeBinomial,
)
from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from scvi.nn import DecoderTOTALVI, EncoderTOTALVI, one_hot

torch.backends.cudnn.benchmark = True
Expand Down Expand Up @@ -649,7 +649,9 @@ def loss(
kl_div_back_pro=kl_div_back_pro,
)

return LossRecorder(loss, reconst_losses, kl_local, kl_global=torch.tensor(0.0))
return LossOutput(
loss=loss, reconstruction_loss=reconst_losses, kl_local=kl_local
)

@torch.inference_mode()
def sample(self, tensors, n_samples=1):
Expand Down Expand Up @@ -698,7 +700,7 @@ def marginal_ll(self, tensors, n_mc_samples):
log_pro_back_mean = generative_outputs["log_pro_back_mean"]

# Reconstruction Loss
reconst_loss = losses._reconstruction_loss
reconst_loss = losses.reconstruction_loss
reconst_loss_gene = reconst_loss["reconst_loss_gene"]
reconst_loss_protein = reconst_loss["reconst_loss_protein"]

Expand Down
Loading

0 comments on commit 0c3808e

Please sign in to comment.