From 68489da061581e33405c6112d9e3cfa0cf3507f5 Mon Sep 17 00:00:00 2001 From: Adam Gayoso Date: Mon, 13 Jun 2022 19:01:52 +0100 Subject: [PATCH] clean up metrics code, use MetricCollection (#1529) * clean up metrics code * condense elbo metric code with filtering inside the update fn * codacy * address comments Co-authored-by: Justin Hong --- scvi/train/_metrics.py | 85 +++++++++++++++++++------------ scvi/train/_trainingplans.py | 98 +++++++++++++++++++----------------- 2 files changed, 107 insertions(+), 76 deletions(-) diff --git a/scvi/train/_metrics.py b/scvi/train/_metrics.py index 951e7bdfa2..2cda7ebb3a 100644 --- a/scvi/train/_metrics.py +++ b/scvi/train/_metrics.py @@ -10,32 +10,38 @@ class ElboMetric(Metric): Parameters ---------- - n_obs_total - Number of total observations, for rescaling the ELBO + name + Name of metric, used as the prefix of the logged name. mode - Train or validation, used for logging names + Train or validation, used as the suffix of the logged name. + interval + The interval over which the metric is computed. If "obs", the metric value + per observation is computed. If "batch", the metric value per batch is computed. dist_sync_on_step - optional, by default False + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. **kwargs Keyword args for :class:`torchmetrics.Metric` """ + _N_OBS_MINIBATCH_KEY = "n_obs_minibatch" + def __init__( self, - n_obs_total: int, + name: str, mode: Literal["train", "validation"], + interval: Literal["obs", "batch"], dist_sync_on_step: bool = False, - **kwargs + **kwargs, ): super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs) - self.n_obs_total = 1 if n_obs_total is None else n_obs_total + self._name = name self._mode = mode + self._interval = interval default_val = torch.tensor(0.0) - self.add_state("reconstruction_loss", default=default_val) - self.add_state("kl_local", default=default_val) - self.add_state("kl_global", default=default_val) + self.add_state("elbo_component", default=default_val) self.add_state("n_obs", default=default_val) self.add_state("n_batches", default=default_val) @@ -43,31 +49,48 @@ def __init__( def mode(self): return self._mode + @property + def name(self): + return f"{self._name}_{self.mode}" + + @name.setter + def name(self, new_name): + self._name = new_name + + @property + def interval(self): + return self._interval + + def get_intervals_recorded(self): + if self.interval == "obs": + return self.n_obs + elif self.interval == "batch": + return self.n_batches + raise ValueError(f"Unrecognized interval: {self.interval}.") + def update( self, - reconstruction_loss_sum: torch.Tensor, - kl_local_sum: torch.Tensor, - kl_global: torch.Tensor, - n_obs_minibatch: int, + **kwargs, ): - """Updates all metrics.""" - reconstruction_loss_sum = reconstruction_loss_sum.detach() - kl_local_sum = kl_local_sum.detach() - kl_global = kl_global.detach() - - self.reconstruction_loss += reconstruction_loss_sum - self.kl_local += kl_local_sum - self.kl_global += kl_global + """ + Updates this metric for one minibatch. + + Takes kwargs associated with all metrics being updated for a given minibatch. + Filters for the relevant metric's value and updates this metric. + """ + if self._N_OBS_MINIBATCH_KEY not in kwargs: + raise ValueError( + f"Missing {self._N_OBS_MINIBATCH_KEY} value in metrics update." 
+            )
+        if self._name not in kwargs:
+            raise ValueError(f"Missing {self._name} value in metrics update.")
+
+        elbo_component = kwargs[self._name].detach()
+        self.elbo_component += elbo_component
+
+        n_obs_minibatch = kwargs[self._N_OBS_MINIBATCH_KEY]
         self.n_obs += n_obs_minibatch
         self.n_batches += 1
 
     def compute(self):
-        avg_reconstruction_loss = self.reconstruction_loss / self.n_obs
-        avg_kl_local = self.kl_local / self.n_obs
-        avg_kl_global = self.kl_global / self.n_batches
-        # elbo on the scale of one observation
-        elbo = (
-            avg_reconstruction_loss + avg_kl_local + (avg_kl_global / self.n_obs_total)
-        )
-
-        return elbo
+        return self.elbo_component / self.get_intervals_recorded()
diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py
index a4db8a9f29..f525b864cd 100644
--- a/scvi/train/_trainingplans.py
+++ b/scvi/train/_trainingplans.py
@@ -6,6 +6,7 @@
 import torch
 from pyro.nn import PyroModule
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torchmetrics import MetricCollection
 
 from scvi import REGISTRY_KEYS
 from scvi._compat import Literal
@@ -134,14 +135,46 @@ def __init__(
         self.initialize_train_metrics()
         self.initialize_val_metrics()
 
+    @staticmethod
+    def _create_elbo_metric_components(mode: str, n_total: Optional[int] = None):
+        """Initialize ELBO metric and the metric collection."""
+        rec_loss = ElboMetric("reconstruction_loss", mode, "obs")
+        kl_local = ElboMetric("kl_local", mode, "obs")
+        kl_global = ElboMetric("kl_global", mode, "batch")
+        # n_total can be 0 if there is no validation set; it won't ever be used
+        # in that case anyway
+        n = 1 if n_total is None or n_total < 1 else n_total
+        elbo = rec_loss + kl_local + (1 / n) * kl_global
+        elbo.name = f"elbo_{mode}"
+        collection = MetricCollection(
+            {metric.name: metric for metric in [elbo, rec_loss, kl_local, kl_global]}
+        )
+        return elbo, rec_loss, kl_local, kl_global, collection
+
     def initialize_train_metrics(self):
         """Initialize train related metrics."""
-        self.elbo_train = ElboMetric(self.n_obs_training, mode="train")
+        (
+            self.elbo_train,
+            self.rec_loss_train,
+            self.kl_local_train,
+            self.kl_global_train,
+            self.train_metrics,
+        ) = self._create_elbo_metric_components(
+            mode="train", n_total=self.n_obs_training
+        )
         self.elbo_train.reset()
 
     def initialize_val_metrics(self):
-        """Initialize train related metrics."""
-        self.elbo_val = ElboMetric(self.n_obs_validation, mode="validation")
+        """Initialize val related metrics."""
+        (
+            self.elbo_val,
+            self.rec_loss_val,
+            self.kl_local_val,
+            self.kl_global_val,
+            self.val_metrics,
+        ) = self._create_elbo_metric_components(
+            mode="validation", n_total=self.n_obs_validation
+        )
         self.elbo_val.reset()
 
     @property
@@ -190,7 +223,8 @@ def forward(self, *args, **kwargs):
     def compute_and_log_metrics(
         self,
         loss_recorder: LossRecorder,
-        elbo_metric: ElboMetric,
+        metrics: MetricCollection,
+        mode: str,
     ):
         """
         Computes and logs metrics.
@@ -201,6 +235,9 @@
         LossRecorder object from scvi-tools module
-        metric_attr_name
-            The name of the torch metric object to use
+        metrics
+            MetricCollection of metrics to update
+        mode
+            Postfix string added to the metric names of
+            extra metrics
         """
         rec_loss = loss_recorder.reconstruction_loss
         n_obs_minibatch = rec_loss.shape[0]
@@ -209,45 +246,16 @@
         kl_local = loss_recorder.kl_local
         kl_global = loss_recorder.kl_global
 
         # use the torchmetric object for the ELBO
-        elbo_metric(
-            rec_loss,
-            kl_local,
-            kl_global,
-            n_obs_minibatch,
+        metrics.update(
+            reconstruction_loss=rec_loss,
+            kl_local=kl_local,
+            kl_global=kl_global,
+            n_obs_minibatch=n_obs_minibatch,
         )
-        # e.g., train or val mode
-        mode = elbo_metric.mode
         # pytorch lightning handles everything with the torchmetric object
-        self.log(
-            f"elbo_{mode}",
-            elbo_metric,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-
-        # log elbo components
-        self.log(
-            f"reconstruction_loss_{mode}",
-            rec_loss / elbo_metric.n_obs_total,
-            reduce_fx=torch.sum,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-        self.log(
-            f"kl_local_{mode}",
-            kl_local / elbo_metric.n_obs_total,
-            reduce_fx=torch.sum,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-        # default aggregation is mean
-        self.log(
-            f"kl_global_{mode}",
-            kl_global,
+        self.log_dict(
+            metrics,
             on_step=False,
             on_epoch=True,
             batch_size=n_obs_minibatch,
@@ -272,7 +280,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
         self.loss_kwargs.update({"kl_weight": self.kl_weight})
         _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         self.log("train_loss", scvi_loss.loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_train)
+        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
         return scvi_loss.loss
 
     def validation_step(self, batch, batch_idx):
@@ -281,7 +289,7 @@ def validation_step(self, batch, batch_idx):
         # of training examples
         _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         self.log("validation_loss", scvi_loss.loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_val)
+        self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")
 
     def configure_optimizers(self):
         params = filter(lambda p: p.requires_grad, self.module.parameters())
@@ -452,7 +460,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
             loss += fool_loss * kappa
 
         self.log("train_loss", loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_train)
+        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
         return loss
 
         # train adversarial classifier
@@ -602,7 +610,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
             on_epoch=True,
             batch_size=len(scvi_losses.reconstruction_loss),
         )
-        self.compute_and_log_metrics(scvi_losses, self.elbo_train)
+        self.compute_and_log_metrics(scvi_losses, self.train_metrics, "train")
         return loss
 
     def validation_step(self, batch, batch_idx, optimizer_idx=0):
@@ -627,7 +635,7 @@ def validation_step(self, batch, batch_idx, optimizer_idx=0):
             on_epoch=True,
             batch_size=len(scvi_losses.reconstruction_loss),
         )
-        self.compute_and_log_metrics(scvi_losses, self.elbo_val)
+        self.compute_and_log_metrics(scvi_losses, self.val_metrics, "validation")
 
 
 class PyroTrainingPlan(pl.LightningModule):
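
A minimal self-contained sketch of the pattern this patch relies on. It is
illustrative only, not scvi-tools code: the Demo*/demo_* names are hypothetical,
and it assumes a torchmetrics version that supports metric arithmetic and
MetricCollection (the same package the patch imports from). Each per-term metric
filters a shared kwargs payload by its own name inside update(), arithmetic on
Metric objects yields a CompositionalMetric, and a MetricCollection fans a single
update() call out to every member:

import torch
from torchmetrics import Metric, MetricCollection

class DemoTermMetric(Metric):
    """Running per-observation mean of one named ELBO term."""

    def __init__(self, name: str):
        super().__init__()
        self._name = name
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("n_obs", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, **kwargs):
        # Filtering happens inside update(): each metric pulls only its own
        # named value out of the shared payload, so one call can feed them all.
        self.total += kwargs[self._name].detach()
        self.n_obs += kwargs["n_obs_minibatch"]

    def compute(self):
        return self.total / self.n_obs

rec_loss = DemoTermMetric("reconstruction_loss")
kl_local = DemoTermMetric("kl_local")
# Arithmetic on Metric objects returns a CompositionalMetric whose compute()
# combines the operands' computed values, mirroring the ELBO built in
# _create_elbo_metric_components above.
demo_elbo = rec_loss + kl_local
collection = MetricCollection(
    {"elbo": demo_elbo, "reconstruction_loss": rec_loss, "kl_local": kl_local}
)

# One minibatch of 4 observations; values are per-minibatch sums of each term.
collection.update(
    reconstruction_loss=torch.tensor(8.0),
    kl_local=torch.tensor(2.0),
    n_obs_minibatch=4,
)
print(collection.compute())
# expected per-observation means: elbo 2.5, reconstruction_loss 2.0, kl_local 0.5

Because each term divides its accumulated sum by the intervals it has recorded,
a term that is updated both directly and through the composed ELBO (as in the
collection the patch builds) still reports the correct mean. Passing the
collection to Lightning's log_dict, as compute_and_log_metrics now does, lets
Lightning drive compute() and reset() at epoch boundaries.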