From 0c3808e857bfb7be429117ebf61ee0fcf792c718 Mon Sep 17 00:00:00 2001 From: Adam Gayoso Date: Tue, 25 Oct 2022 14:51:03 -0700 Subject: [PATCH] Deprecate `LossRecorder`, create `LossOutput` dataclass alternative (#1749) * better optimizer api * better optimizer api * typing * refactor lossrecorder * finalize * more efficient metrics * Revert "more efficient metrics" This reverts commit 5905efb93892cec83138070658c74ed16646b4fe. * no floats * faster metrics * generic sum * compute more things on init for gpu * separate types * fix duplicate code * cleanup * works but slow * improvements * restore old functionality * make lossrecorder somewhat compatbile, more ports * convert more * fix gimvi * fix gimvi * better deprecation * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import * address comments Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 + scvi/external/gimvi/_module.py | 7 +- scvi/external/gimvi/_task.py | 24 ++-- scvi/external/scar/_module.py | 7 +- scvi/external/stereoscope/_module.py | 10 +- scvi/external/tangram/_module.py | 11 +- scvi/model/base/_log_likelihood.py | 14 +- scvi/model/base/_vaemixin.py | 4 +- scvi/module/_jaxvae.py | 6 +- scvi/module/_mrdeconv.py | 9 +- scvi/module/_multivae.py | 5 +- scvi/module/_peakvae.py | 4 +- scvi/module/_scanvae.py | 41 +++--- scvi/module/_totalvae.py | 8 +- scvi/module/_vae.py | 11 +- scvi/module/_vaec.py | 6 +- scvi/module/base/__init__.py | 2 + scvi/module/base/_base_module.py | 189 +++++++++++++++++++++------ scvi/train/_metrics.py | 19 ++- scvi/train/_trainingplans.py | 125 +++++++++--------- 20 files changed, 318 insertions(+), 185 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b50cb1090..fe97c8994e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ version = "0.19.0a0" [tool.poetry.dependencies] anndata = ">=0.7.5" black = {version = ">=22.3", optional = true} +chex = "*" codecov = {version = ">=2.0.8", optional = true} docrep = ">=0.3.2" flake8 = {version = ">=3.7.7", optional = true} diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index 83bb3dc5f8..9b05b79826 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -10,7 +10,7 @@ from scvi import REGISTRY_KEYS from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, MultiDecoder, MultiEncoder, one_hot torch.backends.cudnn.benchmark = True @@ -506,8 +506,9 @@ def loss( kl_divergence_l = torch.zeros_like(kl_divergence_z) kl_local = kl_divergence_l + kl_divergence_z - kl_global = torch.tensor(0.0) loss = torch.mean(reconstruction_loss + kl_weight * kl_local) * x.size(0) - return LossRecorder(loss, reconstruction_loss, kl_local, kl_global) + return LossOutput( + loss=loss, reconstruction_loss=reconstruction_loss, kl_local=kl_local + ) diff --git a/scvi/external/gimvi/_task.py b/scvi/external/gimvi/_task.py index eb0757f227..375fb96693 100644 --- a/scvi/external/gimvi/_task.py +++ b/scvi/external/gimvi/_task.py @@ -31,7 +31,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0): ) if optimizer_idx == 0: # batch contains both data loader outputs - scvi_loss_objs = [] + loss_output_objs = [] n_obs = 0 zs = [] for (i, tensors) in enumerate(batch): @@ -39,19 +39,19 @@ def 
training_step(self, batch, batch_idx, optimizer_idx=0): self.loss_kwargs.update(dict(kl_weight=self.kl_weight, mode=i)) inference_kwargs = dict(mode=i) generative_kwargs = dict(mode=i) - inference_outputs, _, scvi_loss = self.forward( + inference_outputs, _, loss_output = self.forward( tensors, loss_kwargs=self.loss_kwargs, inference_kwargs=inference_kwargs, generative_kwargs=generative_kwargs, ) zs.append(inference_outputs["z"]) - scvi_loss_objs.append(scvi_loss) + loss_output_objs.append(loss_output) - loss = sum([scl.loss for scl in scvi_loss_objs]) + loss = sum([scl.loss for scl in loss_output_objs]) loss /= n_obs - rec_loss = sum([scl.reconstruction_loss.sum() for scl in scvi_loss_objs]) - kl = sum([scl.kl_local.sum() for scl in scvi_loss_objs]) + rec_loss = sum([scl.reconstruction_loss_sum for scl in loss_output_objs]) + kl = sum([scl.kl_local_sum for scl in loss_output_objs]) # fool classifier if doing adversarial training batch_tensor = [ @@ -98,18 +98,18 @@ def validation_step(self, batch, batch_idx, dataloader_idx): self.loss_kwargs.update(dict(kl_weight=self.kl_weight, mode=dataloader_idx)) inference_kwargs = dict(mode=dataloader_idx) generative_kwargs = dict(mode=dataloader_idx) - _, _, scvi_loss = self.forward( + _, _, loss_output = self.forward( batch, loss_kwargs=self.loss_kwargs, inference_kwargs=inference_kwargs, generative_kwargs=generative_kwargs, ) - reconstruction_loss = scvi_loss.reconstruction_loss + reconstruction_loss = loss_output.reconstruction_loss_sum return { - "reconstruction_loss_sum": reconstruction_loss.sum(), - "kl_local_sum": scvi_loss.kl_local.sum(), - "kl_global": scvi_loss.kl_global, - "n_obs": reconstruction_loss.shape[0], + "reconstruction_loss_sum": reconstruction_loss, + "kl_local_sum": loss_output.kl_local_sum, + "kl_global": loss_output.kl_global, + "n_obs": loss_output.n_obs_minibatch, } def validation_epoch_end(self, outputs): diff --git a/scvi/external/scar/_module.py b/scvi/external/scar/_module.py index 3cbc504a2c..707ce6da14 100644 --- a/scvi/external/scar/_module.py +++ b/scvi/external/scar/_module.py @@ -7,7 +7,7 @@ from scvi._compat import Literal from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial from scvi.module._vae import VAE -from scvi.module.base import LossRecorder, auto_move_data +from scvi.module.base import LossOutput, auto_move_data from scvi.nn import FCLayers torch.backends.cudnn.benchmark = True @@ -384,5 +384,6 @@ def loss( kl_local = dict( kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z ) - kl_global = torch.tensor(0.0) - return LossRecorder(loss, reconst_loss, kl_local, kl_global) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local + ) diff --git a/scvi/external/stereoscope/_module.py b/scvi/external/stereoscope/_module.py index d74545aac0..93c1e9eedd 100644 --- a/scvi/external/stereoscope/_module.py +++ b/scvi/external/stereoscope/_module.py @@ -6,7 +6,7 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data class RNADeconv(BaseModuleClass): @@ -110,7 +110,7 @@ def loss( reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1) loss = torch.sum(scaling_factor * reconst_loss) - return LossRecorder(loss, reconst_loss, torch.zeros((1,)), torch.tensor(0.0)) + return LossOutput(loss=loss, reconstruction_loss=reconst_loss) @torch.inference_mode() def sample( @@ 
-239,8 +239,10 @@ def loss( else: # the original way it is done in Stereoscope; we use this option to show reproducibility of their codebase loss = torch.sum(reconst_loss) + neg_log_likelihood_prior - return LossRecorder( - loss, reconst_loss, torch.zeros((1,)), neg_log_likelihood_prior + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_global=neg_log_likelihood_prior, ) @torch.inference_mode() diff --git a/scvi/external/tangram/_module.py b/scvi/external/tangram/_module.py index 3d5e9c53b8..c418b4e349 100644 --- a/scvi/external/tangram/_module.py +++ b/scvi/external/tangram/_module.py @@ -4,7 +4,7 @@ import jax import jax.numpy as jnp -from scvi.module.base import JaxBaseModuleClass, LossRecorder, flax_configure +from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure class _TANGRAM_REGISTRY_KEYS_NT(NamedTuple): @@ -151,4 +151,11 @@ def loss( total_loss = -expression_term - regularizer_term + count_term + f_reg total_loss = total_loss + density_term - return LossRecorder(total_loss, expression_term, regularizer_term) + return LossOutput( + loss=total_loss, + n_obs_minibatch=sp.shape[0], + extra_metrics={ + "expression_term": expression_term, + "regularizer_term": regularizer_term, + }, + ) diff --git a/scvi/model/base/_log_likelihood.py b/scvi/model/base/_log_likelihood.py index c31a9574f4..434d4047dc 100644 --- a/scvi/model/base/_log_likelihood.py +++ b/scvi/model/base/_log_likelihood.py @@ -18,11 +18,11 @@ def compute_elbo(vae, data_loader, feed_labels=True, **kwargs): for tensors in data_loader: _, _, scvi_loss = vae(tensors, **kwargs) - recon_loss = scvi_loss.reconstruction_loss - kl_local = scvi_loss.kl_local - elbo += torch.sum(recon_loss + kl_local).item() + recon_loss = scvi_loss.reconstruction_loss_sum + kl_local = scvi_loss.kl_local_sum + elbo += (recon_loss + kl_local).item() - kl_global = scvi_loss.kl_global + kl_global = scvi_loss.kl_global_sum n_samples = len(data_loader.indices) elbo += kl_global return elbo / n_samples @@ -41,7 +41,11 @@ def compute_reconstruction_error(vae, data_loader, **kwargs): for tensors in data_loader: loss_kwargs = dict(kl_weight=1) _, _, losses = vae(tensors, loss_kwargs=loss_kwargs) - for key, value in losses._reconstruction_loss.items(): + if not isinstance(losses.reconstruction_loss, dict): + rec_loss_dict = {"reconstruction_loss": losses.reconstruction_loss} + else: + rec_loss_dict = losses.reconstruction_loss + for key, value in rec_loss_dict.items(): if key in log_lkl: log_lkl[key] += torch.sum(value).item() else: diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py index ef7de70bd1..ba5257f4c2 100644 --- a/scvi/model/base/_vaemixin.py +++ b/scvi/model/base/_vaemixin.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, Optional, Sequence, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import numpy as np import torch @@ -98,7 +98,7 @@ def get_reconstruction_error( adata: Optional[AnnData] = None, indices: Optional[Sequence[int]] = None, batch_size: Optional[int] = None, - ) -> Union[float, Dict[str, float]]: + ) -> float: r""" Return the reconstruction error for the data. 
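
A minimal sketch of the new return convention follows (illustrative only, not part of the diff: the tensors and sizes are hypothetical, and only `LossOutput` and its fields come from this patch):

    import torch

    from scvi.module.base import LossOutput

    reconst_loss = torch.rand(128)  # hypothetical per-observation reconstruction loss
    kl_z = torch.rand(128)  # hypothetical per-observation KL divergence

    loss_output = LossOutput(
        loss=torch.mean(reconst_loss + kl_z),
        reconstruction_loss=reconst_loss,
        kl_local={"kl_divergence_z": kl_z},  # bare tensors are wrapped into dicts
    )

    # __post_init__ precomputes the sums and infers the minibatch size, so
    # call sites such as compute_elbo above no longer call .sum() themselves.
    assert loss_output.n_obs_minibatch == 128
    elbo_term = loss_output.reconstruction_loss_sum + loss_output.kl_local_sum
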
diff --git a/scvi/module/_jaxvae.py b/scvi/module/_jaxvae.py index d9e085dd19..3e0c0fa630 100644 --- a/scvi/module/_jaxvae.py +++ b/scvi/module/_jaxvae.py @@ -8,7 +8,7 @@ from scvi import REGISTRY_KEYS from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial -from scvi.module.base import JaxBaseModuleClass, LossRecorder, flax_configure +from scvi.module.base import JaxBaseModuleClass, LossOutput, flax_configure class Dense(nn.Dense): @@ -222,4 +222,6 @@ def loss( loss = jnp.mean(reconst_loss + weighted_kl_local) kl_local = kl_divergence_z - return LossRecorder(loss, reconst_loss, kl_local) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local + ) diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 4cb15d8bfd..eb92e0bcb3 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -7,7 +7,7 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import FCLayers @@ -309,8 +309,11 @@ def loss( + glo_neg_log_likelihood_prior ) - return LossRecorder( - loss, reconst_loss, neg_log_likelihood_prior, glo_neg_log_likelihood_prior + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=neg_log_likelihood_prior, + kl_global=glo_neg_log_likelihood_prior, ) @torch.inference_mode() diff --git a/scvi/module/_multivae.py b/scvi/module/_multivae.py index 191507a3b0..fe296d0b86 100644 --- a/scvi/module/_multivae.py +++ b/scvi/module/_multivae.py @@ -15,7 +15,7 @@ ZeroInflatedNegativeBinomial, ) from scvi.module._peakvae import Decoder as DecoderPeakVI -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, FCLayers, one_hot from ._utils import masked_softmax @@ -873,8 +873,7 @@ def loss( loss = torch.mean(recon_loss + weighted_kl_local + kld_paired) kl_local = dict(kl_divergence_z=kl_div_z) - kl_global = torch.tensor(0.0) - return LossRecorder(loss, recon_loss, kl_local, kl_global) + return LossOutput(loss=loss, reconstruction_loss=recon_loss, kl_local=kl_local) def get_reconstruction_loss_expression(self, x, px_rate, px_r, px_dropout): """Computes the reconstruction loss for the expression data.""" diff --git a/scvi/module/_peakvae.py b/scvi/module/_peakvae.py index 322a81d2ca..12cd8360a6 100644 --- a/scvi/module/_peakvae.py +++ b/scvi/module/_peakvae.py @@ -7,7 +7,7 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, FCLayers @@ -335,4 +335,4 @@ def loss( loss = (rl.sum() + kld * kl_weight).sum() - return LossRecorder(loss, rl, kld, kl_global=torch.tensor(0.0)) + return LossOutput(loss=loss, reconstruction_loss=rl, kl_local=kld) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 4525fe66df..d0863717c6 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -8,7 +8,7 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal -from scvi.module.base import LossRecorder, auto_move_data +from scvi.module.base import LossOutput, auto_move_data from scvi.nn import Decoder, Encoder from ._classifier import Classifier @@ 
-300,18 +300,21 @@ def loss( if labelled_tensors is not None: classifier_loss = self.classification_loss(labelled_tensors) loss += classifier_loss * classification_ratio - return LossRecorder( - loss, - reconst_loss, - kl_locals, - classification_loss=classifier_loss, - n_labelled_tensors=labelled_tensors[REGISTRY_KEYS.X_KEY].shape[0], + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_locals, + extra_metrics={ + "classification_loss": classifier_loss, + "n_labelled_tensors": labelled_tensors[ + REGISTRY_KEYS.X_KEY + ].shape[0], + }, ) - return LossRecorder( - loss, - reconst_loss, - kl_locals, - kl_global=torch.tensor(0.0), + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_locals, ) probs = self.classifier(z1) @@ -333,10 +336,12 @@ def loss( if labelled_tensors is not None: classifier_loss = self.classification_loss(labelled_tensors) loss += classifier_loss * classification_ratio - return LossRecorder( - loss, - reconst_loss, - kl_divergence, - classification_loss=classifier_loss, + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_divergence, + extra_metrics={"classification_loss": classifier_loss}, ) - return LossRecorder(loss, reconst_loss, kl_divergence) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence + ) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index 20d18cb432..79977ff34b 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -14,7 +14,7 @@ NegativeBinomialMixture, ZeroInflatedNegativeBinomial, ) -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderTOTALVI, EncoderTOTALVI, one_hot torch.backends.cudnn.benchmark = True @@ -649,7 +649,9 @@ def loss( kl_div_back_pro=kl_div_back_pro, ) - return LossRecorder(loss, reconst_losses, kl_local, kl_global=torch.tensor(0.0)) + return LossOutput( + loss=loss, reconstruction_loss=reconst_losses, kl_local=kl_local + ) @torch.inference_mode() def sample(self, tensors, n_samples=1): @@ -698,7 +700,7 @@ def marginal_ll(self, tensors, n_mc_samples): log_pro_back_mean = generative_outputs["log_pro_back_mean"] # Reconstruction Loss - reconst_loss = losses._reconstruction_loss + reconst_loss = losses.reconstruction_loss reconst_loss_gene = reconst_loss["reconst_loss_gene"] reconst_loss_protein = reconst_loss["reconst_loss_protein"] diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 148bbec286..51c7d6446e 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -12,7 +12,7 @@ from scvi._compat import Literal from scvi._types import LatentDataType from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial -from scvi.module.base import BaseLatentModeModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseLatentModeModuleClass, LossOutput, auto_move_data from scvi.nn import DecoderSCVI, Encoder, LinearDecoderSCVI, one_hot torch.backends.cudnn.benchmark = True @@ -441,7 +441,7 @@ def loss( generative_outputs["pl"], ).sum(dim=1) else: - kl_divergence_l = 0.0 + kl_divergence_l = torch.tensor(0.0, device=x.device) reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1) @@ -455,8 +455,9 @@ def loss( kl_local = dict( kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z ) - kl_global = torch.tensor(0.0) - return LossRecorder(loss, reconst_loss, kl_local, kl_global) + return LossOutput( 
+ loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local + ) @torch.inference_mode() def sample( @@ -525,7 +526,7 @@ def marginal_ll(self, tensors, n_mc_samples): library = inference_outputs["library"] # Reconstruction Loss - reconst_loss = losses.reconstruction_loss + reconst_loss = losses.dict_sum(losses.reconstruction_loss) # Log-probabilities p_z = ( diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 554ddf4db2..40a2491089 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -5,7 +5,7 @@ from scvi import REGISTRY_KEYS from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data +from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, FCLayers torch.backends.cudnn.benchmark = True @@ -177,7 +177,9 @@ def loss( scaling_factor = self.ct_weight[y.long()[:, 0]] loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) - return LossRecorder(loss, reconst_loss, kl_divergence_z, torch.tensor(0.0)) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z + ) @torch.inference_mode() def sample( diff --git a/scvi/module/base/__init__.py b/scvi/module/base/__init__.py index a843662d8e..bc884fb3e1 100644 --- a/scvi/module/base/__init__.py +++ b/scvi/module/base/__init__.py @@ -2,6 +2,7 @@ BaseLatentModeModuleClass, BaseModuleClass, JaxBaseModuleClass, + LossOutput, LossRecorder, PyroBaseModuleClass, TrainStateWithState, @@ -11,6 +12,7 @@ __all__ = [ "BaseModuleClass", "LossRecorder", + "LossOutput", "PyroBaseModuleClass", "auto_move_data", "flax_configure", diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index 30bf3d957e..3be7386422 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -1,8 +1,11 @@ from __future__ import annotations +import warnings from abc import abstractmethod +from dataclasses import field from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union +import chex import flax import jax import jax.numpy as jnp @@ -18,7 +21,7 @@ from torch import nn from scvi import settings -from scvi._types import LatentDataType, LossRecord +from scvi._types import LatentDataType, LossRecord, Tensor from scvi.utils._jax import device_selecting_PRNGKey from ._decorators import auto_move_data @@ -33,18 +36,20 @@ class LossRecorder: the components of the ELBO. This may also be used in MLE, MAP, EM methods. The loss is used for backpropagation during inference. The other parameters are used for logging/early stopping during inference. - Parameters ---------- loss Tensor with loss for minibatch. Should be one dimensional with one value. Note that loss should be a :class:`~torch.Tensor` and not the result of ``.item()``. reconstruction_loss - Reconstruction loss for each observation in the minibatch. + Reconstruction loss for each observation in the minibatch. If a tensor, converted to + a dictionary with key "reconstruction_loss" and value as tensor kl_local - KL divergence associated with each observation in the minibatch. + KL divergence associated with each observation in the minibatch. If a tensor, converted to + a dictionary with key "kl_local" and value as tensor kl_global - Global kl divergence term. Should be one dimensional with one value. + Global kl divergence term. Should be one dimensional with one value. 
If a tensor, converted to + a dictionary with key "kl_global" and value as tensor **kwargs Additional metrics can be passed as keyword arguments and will be available as attributes of the object. @@ -58,56 +63,158 @@ def __init__( kl_global: Optional[LossRecord] = None, **kwargs, ): - - default = ( - torch.tensor(0.0) if isinstance(loss, torch.Tensor) else jnp.array(0.0) - ) - if reconstruction_loss is None: - reconstruction_loss = default - if kl_local is None: - kl_local = default - if kl_global is None: - kl_global = default - - self._loss = loss if isinstance(loss, dict) else dict(loss=loss) - self._reconstruction_loss = ( - reconstruction_loss - if isinstance(reconstruction_loss, dict) - else dict(reconstruction_loss=reconstruction_loss) + warnings.warn( + "LossRecorder is deprecated and will be removed in version 0.20.0. Please use LossOutput", + category=DeprecationWarning, ) - self._kl_local = ( - kl_local if isinstance(kl_local, dict) else dict(kl_local=kl_local) - ) - self._kl_global = ( - kl_global if isinstance(kl_global, dict) else dict(kl_global=kl_global) + self._loss_output = LossOutput( + loss=loss, + reconstruction_loss=reconstruction_loss, + kl_local=kl_local, + kl_global=kl_global, + extra_metrics=kwargs, ) self.extra_metric_attrs = [] for key, value in kwargs.items(): setattr(self, key, value) self.extra_metric_attrs.append(key) - @staticmethod - def _get_dict_sum(dictionary): - total = 0.0 - for value in dictionary.values(): - total += value - return total - @property def loss(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 - return self._get_dict_sum(self._loss) + return self._loss_output.loss @property def reconstruction_loss(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 - return self._get_dict_sum(self._reconstruction_loss) + return self.dict_sum(self._loss_output.reconstruction_loss) + + @property + def _reconstruction_loss(self): + return self._loss_output.reconstruction_loss @property def kl_local(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 - return self._get_dict_sum(self._kl_local) + return self.dict_sum(self._loss_output.kl_local) + + @property + def _kl_local(self): + return self._loss_output.kl_local + + @property + def reconstruction_loss_sum(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 + return self._loss_output.reconstruction_loss_sum + + @property + def kl_local_sum(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 + return self._loss_output.kl_local_sum + + @property + def kl_global_sum(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 + return self._loss_output.kl_global_sum @property def kl_global(self) -> Union[torch.Tensor, jnp.ndarray]: # noqa: D102 - return self._get_dict_sum(self._kl_global) + return self.dict_sum(self._loss_output.kl_global) + + def dict_sum(self, x): + """Wrapper of LossOutput.dict_sum.""" + return self._loss_output.dict_sum(x) + + +@chex.dataclass +class LossOutput: + """ + Loss signature for models. + + This class provides an organized way to record the model loss, as well as + the components of the ELBO. This may also be used in MLE, MAP, EM methods. + The loss is used for backpropagation during inference. The other parameters + are used for logging/early stopping during inference. + + Parameters + ---------- + loss + Tensor with loss for minibatch. Should be one dimensional with one value. + Note that loss should be in an array/tensor and not a float. + reconstruction_loss + Reconstruction loss for each observation in the minibatch. 
If a tensor, converted to
+        a dictionary with key "reconstruction_loss" and value as tensor.
+    kl_local
+        KL divergence associated with each observation in the minibatch. If a tensor, converted to
+        a dictionary with key "kl_local" and value as tensor.
+    kl_global
+        Global KL divergence term. Should be one dimensional with one value. If a tensor, converted to
+        a dictionary with key "kl_global" and value as tensor.
+    extra_metrics
+        Additional metrics can be passed as arrays/tensors or dictionaries of
+        arrays/tensors.
+    n_obs_minibatch
+        Number of observations in the minibatch. If None, will be inferred from
+        the shape of the reconstruction_loss tensor.
+    reconstruction_loss_sum
+        Sum of the reconstruction loss across the minibatch. Will be computed
+        automatically.
+    kl_local_sum
+        Sum of the kl_local across the minibatch. Will be computed
+        automatically.
+    kl_global_sum
+        Sum of the kl_global terms. Will be computed automatically.
+    """
+
+    loss: LossRecord
+    reconstruction_loss: Optional[LossRecord] = None
+    kl_local: Optional[LossRecord] = None
+    kl_global: Optional[LossRecord] = None
+    extra_metrics: Optional[Dict[str, Tensor]] = field(default_factory=dict)
+    n_obs_minibatch: Optional[int] = None
+    reconstruction_loss_sum: Tensor = field(default=None, init=False)
+    kl_local_sum: Tensor = field(default=None, init=False)
+    kl_global_sum: Tensor = field(default=None, init=False)
+
+    def __post_init__(self):
+        self.loss = self.dict_sum(self.loss)
+
+        if self.n_obs_minibatch is None and self.reconstruction_loss is None:
+            raise ValueError(
+                "Must provide either n_obs_minibatch or reconstruction_loss"
+            )
+
+        default = 0 * self.loss
+        if self.reconstruction_loss is None:
+            self.reconstruction_loss = default
+        if self.kl_local is None:
+            self.kl_local = default
+        if self.kl_global is None:
+            self.kl_global = default
+        self.reconstruction_loss = self._as_dict("reconstruction_loss")
+        self.kl_local = self._as_dict("kl_local")
+        self.kl_global = self._as_dict("kl_global")
+        self.reconstruction_loss_sum = self.dict_sum(self.reconstruction_loss).sum()
+        self.kl_local_sum = self.dict_sum(self.kl_local).sum()
+        self.kl_global_sum = self.dict_sum(self.kl_global)
+
+        if self.reconstruction_loss is not None and self.n_obs_minibatch is None:
+            rec_loss = self.reconstruction_loss
+            self.n_obs_minibatch = list(rec_loss.values())[0].shape[0]
+
+    @staticmethod
+    def dict_sum(dictionary: Union[Dict[str, Tensor], Tensor]):
+        """Sum over elements of a dictionary."""
+        if isinstance(dictionary, dict):
+            return sum(dictionary.values())
+        else:
+            return dictionary
+
+    @property
+    def extra_metrics_keys(self) -> Iterable[str]:
+        """Keys for extra metrics."""
+        return self.extra_metrics.keys()
+
+    def _as_dict(self, attr_name: str):
+        attr = getattr(self, attr_name)
+        if isinstance(attr, dict):
+            return attr
+        else:
+            return {attr_name: attr}
 
 
 class BaseModuleClass(nn.Module):
@@ -217,14 +324,14 @@ def generative(
         """
 
     @abstractmethod
-    def loss(self, *args, **kwargs) -> LossRecorder:
+    def loss(self, *args, **kwargs) -> LossOutput:
         """
         Compute the loss for a minibatch of data.
 
         This function uses the outputs of the inference and generative functions
         to compute a loss. This may optionally include other penalty terms,
         which should be computed here.
 
-        This function should return an object of type :class:`~scvi.module.base.LossRecorder`.
+        This function should return an object of type :class:`~scvi.module.base.LossOutput`.
""" @abstractmethod @@ -547,14 +654,14 @@ def generative( """ @abstractmethod - def loss(self, *args, **kwargs) -> LossRecorder: + def loss(self, *args, **kwargs) -> LossOutput: """ Compute the loss for a minibatch of data. This function uses the outputs of the inference and generative functions to compute a loss. This many optionally include other penalty terms, which should be computed here. - This function should return an object of type :class:`~scvi.module.base.LossRecorder`. + This function should return an object of type :class:`~scvi.module.base.LossOutput`. """ @property diff --git a/scvi/train/_metrics.py b/scvi/train/_metrics.py index 1f44b2d199..6fa8e1a0a1 100644 --- a/scvi/train/_metrics.py +++ b/scvi/train/_metrics.py @@ -17,15 +17,12 @@ class ElboMetric(Metric): interval The interval over which the metric is computed. If "obs", the metric value per observation is computed. If "batch", the metric value per batch is computed. - dist_sync_on_step - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. **kwargs Keyword args for :class:`torchmetrics.Metric` """ # Needs to be explicitly set to avoid TorchMetrics UserWarning. - full_state_update = True + full_state_update = False _N_OBS_MINIBATCH_KEY = "n_obs_minibatch" def __init__( @@ -33,19 +30,19 @@ def __init__( name: str, mode: Literal["train", "validation"], interval: Literal["obs", "batch"], - dist_sync_on_step: bool = False, **kwargs, ): - super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs) + super().__init__(**kwargs) self._name = name self._mode = mode self._interval = interval - default_val = torch.tensor(0.0) - self.add_state("elbo_component", default=default_val) - self.add_state("n_obs", default=default_val) - self.add_state("n_batches", default=default_val) + self.add_state( + "elbo_component", default=torch.tensor(0.0), dist_reduce_fx="sum" + ) + self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("n_batches", default=torch.tensor(0.0), dist_reduce_fx="sum") @property def mode(self): # noqa: D102 @@ -88,7 +85,7 @@ def update( if self._name not in kwargs: raise ValueError(f"Missing {self._name} value in metrics update.") - elbo_component = kwargs[self._name].detach() + elbo_component = kwargs[self._name] self.elbo_component += elbo_component n_obs_minibatch = kwargs[self._N_OBS_MINIBATCH_KEY] diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index ea874a86ed..8c5274da60 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -1,3 +1,4 @@ +from collections import OrderedDict from functools import partial from inspect import signature from typing import Callable, Dict, Iterable, Optional, Union @@ -11,7 +12,6 @@ import torch from pyro.nn import PyroModule from torch.optim.lr_scheduler import ReduceLROnPlateau -from torchmetrics import MetricCollection from scvi import REGISTRY_KEYS from scvi._compat import Literal @@ -19,6 +19,7 @@ from scvi.module.base import ( BaseModuleClass, JaxBaseModuleClass, + LossOutput, LossRecorder, PyroBaseModuleClass, TrainStateWithState, @@ -27,6 +28,9 @@ from ._metrics import ElboMetric +JaxOptimizerCreator = Callable[[], optax.GradientTransformation] +TorchOptimizerCreator = Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer] + def _compute_kl_weight( epoch: int, @@ -139,9 +143,7 @@ def __init__( module: BaseModuleClass, *, optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: Optional[ - 
-            Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]
-        ] = None,
+        optimizer_creator: Optional[TorchOptimizerCreator] = None,
         lr: float = 1e-3,
         weight_decay: float = 1e-6,
         eps: float = 0.01,
@@ -205,8 +207,8 @@ def _create_elbo_metric_components(mode: str, n_total: Optional[int] = None):
     n = 1 if n_total is None or n_total < 1 else n_total
     elbo = rec_loss + kl_local + (1 / n) * kl_global
     elbo.name = f"elbo_{mode}"
-    collection = MetricCollection(
-        {metric.name: metric for metric in [elbo, rec_loss, kl_local, kl_global]}
+    collection = OrderedDict(
+        [(metric.name, metric) for metric in [elbo, rec_loss, kl_local, kl_global]]
     )
     return elbo, rec_loss, kl_local, kl_global, collection

@@ -281,8 +283,8 @@ def forward(self, *args, **kwargs):

     @torch.inference_mode()
     def compute_and_log_metrics(
         self,
-        loss_recorder: LossRecorder,
-        metrics: MetricCollection,
+        loss_recorder: Union[LossRecorder, LossOutput],
+        metrics: Dict[str, ElboMetric],
         mode: str,
     ):
         """
@@ -298,14 +300,19 @@
             Postfix string to add to the metric name of extra metrics
         """
-        rec_loss = loss_recorder.reconstruction_loss
-        n_obs_minibatch = rec_loss.shape[0]
-        rec_loss = rec_loss.sum()
-        kl_local = loss_recorder.kl_local.sum()
-        kl_global = loss_recorder.kl_global
-
-        # use the torchmetric object for the ELBO
-        metrics.update(
+        if isinstance(loss_recorder, LossRecorder):
+            loss_output = loss_recorder._loss_output
+        else:
+            loss_output = loss_recorder
+        rec_loss = loss_output.reconstruction_loss_sum
+        n_obs_minibatch = loss_output.n_obs_minibatch
+        kl_local = loss_output.kl_local_sum
+        kl_global = loss_output.kl_global_sum
+
+        # Use the torchmetric object for the ELBO.
+        # We only need to update the ELBO metric,
+        # as it is defined as a sum of the other metrics.
+        metrics[f"elbo_{mode}"].update(
             reconstruction_loss=rec_loss,
             kl_local=kl_local,
             kl_global=kl_global,
@@ -320,14 +327,14 @@
         )

         # accumulate extra metrics passed to loss recorder
-        for extra_metric in loss_recorder.extra_metric_attrs:
-            met = getattr(loss_recorder, extra_metric)
+        for key in loss_output.extra_metrics_keys:
+            met = loss_output.extra_metrics[key]
             if isinstance(met, torch.Tensor):
                 if met.shape != torch.Size([]):
                     raise ValueError("Extra tracked metrics should be 0-d tensors.")
                 met = met.detach()
             self.log(
-                f"{extra_metric}_{mode}",
+                f"{key}_{mode}",
                 met,
                 on_step=False,
                 on_epoch=True,
@@ -469,9 +476,7 @@ def __init__(
         module: BaseModuleClass,
         *,
         optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam",
-        optimizer_creator: Optional[
-            Callable[[Iterable[torch.Tensor]], torch.optim.Optimizer]
-        ] = None,
+        optimizer_creator: Optional[TorchOptimizerCreator] = None,
         lr: float = 1e-3,
         weight_decay: float = 1e-6,
         n_steps_kl_warmup: Union[int, None] = None,
@@ -702,15 +707,15 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
                 labelled_tensors=labelled_dataset,
             )
         input_kwargs.update(self.loss_kwargs)
-        _, _, scvi_losses = self.forward(full_dataset, loss_kwargs=input_kwargs)
-        loss = scvi_losses.loss
+        _, _, loss_output = self.forward(full_dataset, loss_kwargs=input_kwargs)
+        loss = loss_output.loss
         self.log(
             "train_loss",
             loss,
             on_epoch=True,
-            batch_size=len(scvi_losses.reconstruction_loss),
+            batch_size=loss_output.n_obs_minibatch,
         )
-        self.compute_and_log_metrics(scvi_losses, self.train_metrics, "train")
+        self.compute_and_log_metrics(loss_output, self.train_metrics, "train")
         return loss

     def validation_step(self, batch, batch_idx, optimizer_idx=0):
@@ -728,15 +733,15 @@ def 
validation_step(self, batch, batch_idx, optimizer_idx=0): labelled_tensors=labelled_dataset, ) input_kwargs.update(self.loss_kwargs) - _, _, scvi_losses = self.forward(full_dataset, loss_kwargs=input_kwargs) - loss = scvi_losses.loss + _, _, loss_output = self.forward(full_dataset, loss_kwargs=input_kwargs) + loss = loss_output.loss self.log( "validation_loss", loss, on_epoch=True, - batch_size=len(scvi_losses.reconstruction_loss), + batch_size=loss_output.n_obs_minibatch, ) - self.compute_and_log_metrics(scvi_losses, self.val_metrics, "validation") + self.compute_and_log_metrics(loss_output, self.val_metrics, "validation") class PyroTrainingPlan(pl.LightningModule): @@ -1022,7 +1027,7 @@ def __init__( module: JaxBaseModuleClass, *, optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: Optional[Callable[[], optax.GradientTransformation]] = None, + optimizer_creator: Optional[JaxOptimizerCreator] = None, lr: float = 1e-3, weight_decay: float = 1e-6, eps: float = 0.01, @@ -1046,7 +1051,7 @@ def __init__( self.automatic_optimization = False self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) - def get_optimizer_creator(self) -> Callable[[], optax.GradientTransformation]: + def get_optimizer_creator(self) -> JaxOptimizerCreator: """Get optimizer creator for the model.""" clip_by = ( optax.clip_by_global_norm(self.max_norm) @@ -1101,43 +1106,40 @@ def loss_fn(params): outputs, new_model_state = state.apply_fn( vars_in, batch, rngs=rngs, mutable=list(state.state.keys()), **kwargs ) - loss_recorder = outputs[2] - loss = loss_recorder.loss - elbo = jnp.mean(loss_recorder.reconstruction_loss + loss_recorder.kl_local) - return loss, (elbo, new_model_state) + loss_output = outputs[2] + loss = loss_output.loss + return loss, (loss_output, new_model_state) - (loss, (elbo, new_model_state)), grads = jax.value_and_grad( + (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( loss_fn, has_aux=True )(state.params) new_state = state.apply_gradients(grads=grads, state=new_model_state) - return new_state, loss, elbo + return new_state, loss, loss_output def training_step(self, batch, batch_idx): """Training step for Jax.""" if "kl_weight" in self.loss_kwargs: self.loss_kwargs.update({"kl_weight": self.kl_weight}) self.module.train() - self.module.train_state, loss, elbo = self.jit_training_step( + self.module.train_state, _, loss_output = self.jit_training_step( self.module.train_state, batch, self.module.rngs, loss_kwargs=self.loss_kwargs, ) - loss = torch.tensor(jax.device_get(loss)) - elbo = torch.tensor(jax.device_get(elbo)) + loss_output = jax.tree_util.tree_map( + lambda x: torch.tensor(jax.device_get(x)), + loss_output, + ) # TODO: Better way to get batch size self.log( "train_loss", - loss, - on_epoch=True, - batch_size=batch[REGISTRY_KEYS.X_KEY].shape[0], - ) - self.log( - "elbo_train", - elbo, + loss_output.loss, on_epoch=True, - batch_size=batch[REGISTRY_KEYS.X_KEY].shape[0], + batch_size=loss_output.n_obs_minibatch, + prog_bar=True, ) + self.compute_and_log_metrics(loss_output, self.train_metrics, "train") @partial(jax.jit, static_argnums=(0,)) def jit_validation_step( @@ -1150,35 +1152,30 @@ def jit_validation_step( """Jit validation step.""" vars_in = {"params": state.params, **state.state} outputs = self.module.apply(vars_in, batch, rngs=rngs, **kwargs) - loss_recorder = outputs[2] - loss = loss_recorder.loss - elbo = jnp.mean(loss_recorder.reconstruction_loss + loss_recorder.kl_local) + loss_output = outputs[2] - return loss, elbo + return 
loss_output def validation_step(self, batch, batch_idx): """Validation step for Jax.""" self.module.eval() - loss, elbo = self.jit_validation_step( + loss_output = self.jit_validation_step( self.module.train_state, batch, self.module.rngs, loss_kwargs=self.loss_kwargs, ) - loss = torch.tensor(jax.device_get(loss)) - elbo = torch.tensor(jax.device_get(elbo)) - self.log( - "validation_loss", - loss, - on_epoch=True, - batch_size=batch[REGISTRY_KEYS.X_KEY].shape[0], + loss_output = jax.tree_util.tree_map( + lambda x: torch.tensor(jax.device_get(x)), + loss_output, ) self.log( - "elbo_validation", - elbo, + "validation_loss", + loss_output.loss, on_epoch=True, - batch_size=batch[REGISTRY_KEYS.X_KEY].shape[0], + batch_size=loss_output.n_obs_minibatch, ) + self.compute_and_log_metrics(loss_output, self.val_metrics, "validation") @staticmethod def transfer_batch_to_device(batch, device, dataloader_idx):
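
For code that still constructs `LossRecorder`, the shim above keeps the legacy attribute surface while delegating to `LossOutput`. A sketch of the migration path (hypothetical tensors; the warning and delegation behavior are as defined in `_base_module.py` above):

    import warnings

    import torch

    from scvi.module.base import LossOutput, LossRecorder

    rec = torch.rand(64)  # hypothetical per-observation reconstruction loss
    kl = torch.rand(64)  # hypothetical per-observation KL divergence

    # Old style: now emits a DeprecationWarning and wraps a LossOutput internally.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        recorder = LossRecorder(torch.mean(rec + kl), rec, kl)
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)

    # New style: the same quantities, computed eagerly in __post_init__.
    output = LossOutput(loss=torch.mean(rec + kl), reconstruction_loss=rec, kl_local=kl)

    # The shim exposes the new *_sum fields alongside the legacy per-observation
    # properties, so compute_and_log_metrics can treat both objects uniformly.
    assert torch.allclose(recorder.reconstruction_loss_sum, output.reconstruction_loss_sum)
    assert torch.allclose(recorder.kl_local_sum, output.kl_local_sum)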