
clean up metrics code, use MetricCollection (#1529)
* clean up metrics code

* condense elbo metric code with filtering inside the update fn

* codacy

* address comments

Co-authored-by: Justin Hong <[email protected]>
adamgayoso and justjhong authored Jun 13, 2022
1 parent 73952af commit aecf24d
Showing 2 changed files with 107 additions and 76 deletions.
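
For context before the diff: this change moves the training plans onto torchmetrics' MetricCollection, whose update() fans one shared set of keyword arguments out to every contained metric, so each metric filters for its own value, "filtering inside the update fn" as the commit message puts it. A minimal standalone sketch of that pattern (SummedValueMetric and the metric names are illustrative, not scvi-tools code; assumes a torchmetrics version of this era):

import torch
from torchmetrics import Metric, MetricCollection

class SummedValueMetric(Metric):
    """Toy metric that pulls one named value out of shared update kwargs."""

    def __init__(self, key: str):
        super().__init__()
        self._key = key
        self.add_state("total", default=torch.tensor(0.0))
        self.add_state("n_batches", default=torch.tensor(0.0))

    def update(self, **kwargs):
        # Filtering inside the update fn: keep only this metric's entry.
        self.total += kwargs[self._key].detach()
        self.n_batches += 1

    def compute(self):
        return self.total / self.n_batches

metrics = MetricCollection(
    {k: SummedValueMetric(k) for k in ("reconstruction_loss", "kl_local")}
)
# One call updates every metric in the collection with the same kwargs.
metrics.update(reconstruction_loss=torch.tensor(12.0), kl_local=torch.tensor(6.0))
print(metrics.compute())  # {'reconstruction_loss': tensor(12.), 'kl_local': tensor(6.)}
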
85 changes: 54 additions & 31 deletions scvi/train/_metrics.py
@@ -10,64 +10,87 @@ class ElboMetric(Metric):
     Parameters
     ----------
-    n_obs_total
-        Number of total observations, for rescaling the ELBO
+    name
+        Name of metric, used as the prefix of the logged name.
     mode
-        Train or validation, used for logging names
+        Train or validation, used as the suffix of the logged name.
+    interval
+        The interval over which the metric is computed. If "obs", the metric value
+        per observation is computed. If "batch", the metric value per batch is computed.
     dist_sync_on_step
         optional, by default False
         Synchronize metric state across processes at each ``forward()``
         before returning the value at the step.
     **kwargs
         Keyword args for :class:`torchmetrics.Metric`
     """

+    _N_OBS_MINIBATCH_KEY = "n_obs_minibatch"
+
     def __init__(
         self,
-        n_obs_total: int,
+        name: str,
         mode: Literal["train", "validation"],
+        interval: Literal["obs", "batch"],
         dist_sync_on_step: bool = False,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(dist_sync_on_step=dist_sync_on_step, **kwargs)

-        self.n_obs_total = 1 if n_obs_total is None else n_obs_total
+        self._name = name
         self._mode = mode
+        self._interval = interval

         default_val = torch.tensor(0.0)
-        self.add_state("reconstruction_loss", default=default_val)
-        self.add_state("kl_local", default=default_val)
-        self.add_state("kl_global", default=default_val)
+        self.add_state("elbo_component", default=default_val)
         self.add_state("n_obs", default=default_val)
         self.add_state("n_batches", default=default_val)

     @property
     def mode(self):
         return self._mode

+    @property
+    def name(self):
+        return f"{self._name}_{self.mode}"
+
+    @name.setter
+    def name(self, new_name):
+        self._name = new_name
+
+    @property
+    def interval(self):
+        return self._interval
+
+    def get_intervals_recorded(self):
+        if self.interval == "obs":
+            return self.n_obs
+        elif self.interval == "batch":
+            return self.n_batches
+        raise ValueError(f"Unrecognized interval: {self.interval}.")
+
     def update(
         self,
-        reconstruction_loss_sum: torch.Tensor,
-        kl_local_sum: torch.Tensor,
-        kl_global: torch.Tensor,
-        n_obs_minibatch: int,
         **kwargs,
     ):
-        """Updates all metrics."""
-        reconstruction_loss_sum = reconstruction_loss_sum.detach()
-        kl_local_sum = kl_local_sum.detach()
-        kl_global = kl_global.detach()
-
-        self.reconstruction_loss += reconstruction_loss_sum
-        self.kl_local += kl_local_sum
-        self.kl_global += kl_global
+        """
+        Updates this metric for one minibatch.
+
+        Takes kwargs associated with all metrics being updated for a given minibatch.
+        Filters for the relevant metric's value and updates this metric.
+        """
+        if self._N_OBS_MINIBATCH_KEY not in kwargs:
+            raise ValueError(
+                f"Missing {self._N_OBS_MINIBATCH_KEY} value in metrics update."
+            )
+        if self._name not in kwargs:
+            raise ValueError(f"Missing {self._name} value in metrics update.")
+
+        elbo_component = kwargs[self._name].detach()
+        self.elbo_component += elbo_component
+
+        n_obs_minibatch = kwargs[self._N_OBS_MINIBATCH_KEY]
         self.n_obs += n_obs_minibatch
         self.n_batches += 1

     def compute(self):
-        avg_reconstruction_loss = self.reconstruction_loss / self.n_obs
-        avg_kl_local = self.kl_local / self.n_obs
-        avg_kl_global = self.kl_global / self.n_batches
-        # elbo on the scale of one observation
-        elbo = (
-            avg_reconstruction_loss + avg_kl_local + (avg_kl_global / self.n_obs_total)
-        )
-
-        return elbo
+        return self.elbo_component / self.get_intervals_recorded()
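
A hedged usage sketch of the condensed class above, with made-up numbers; the import path assumes the module layout in this commit:

import torch
from scvi.train._metrics import ElboMetric

rec_loss = ElboMetric("reconstruction_loss", mode="train", interval="obs")

# The training plan passes one shared kwargs dict to every metric; this
# metric keeps its own entry plus n_obs_minibatch and ignores the rest.
rec_loss.update(
    reconstruction_loss=torch.tensor(320.0),  # summed over the minibatch
    kl_local=torch.tensor(64.0),  # present for other metrics, ignored here
    n_obs_minibatch=128,
)
print(rec_loss.name)       # reconstruction_loss_train
print(rec_loss.compute())  # tensor(2.5000), i.e. 320 / 128 per observation
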
98 changes: 53 additions & 45 deletions scvi/train/_trainingplans.py
@@ -6,6 +6,7 @@
 import torch
 from pyro.nn import PyroModule
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torchmetrics import MetricCollection

 from scvi import REGISTRY_KEYS
 from scvi._compat import Literal
@@ -134,14 +135,46 @@ def __init__(
         self.initialize_train_metrics()
         self.initialize_val_metrics()

+    @staticmethod
+    def _create_elbo_metric_components(mode: str, n_total: Optional[int] = None):
+        """Initialize ELBO metric and the metric collection."""
+        rec_loss = ElboMetric("reconstruction_loss", mode, "obs")
+        kl_local = ElboMetric("kl_local", mode, "obs")
+        kl_global = ElboMetric("kl_global", mode, "batch")
+        # n_total can be 0 if there is no validation set, this won't ever be used
+        # in that case anyway
+        n = 1 if n_total is None or n_total < 1 else n_total
+        elbo = rec_loss + kl_local + (1 / n) * kl_global
+        elbo.name = f"elbo_{mode}"
+        collection = MetricCollection(
+            {metric.name: metric for metric in [elbo, rec_loss, kl_local, kl_global]}
+        )
+        return elbo, rec_loss, kl_local, kl_global, collection
+
     def initialize_train_metrics(self):
         """Initialize train related metrics."""
-        self.elbo_train = ElboMetric(self.n_obs_training, mode="train")
+        (
+            self.elbo_train,
+            self.rec_loss_train,
+            self.kl_local_train,
+            self.kl_global_train,
+            self.train_metrics,
+        ) = self._create_elbo_metric_components(
+            mode="train", n_total=self.n_obs_training
+        )
         self.elbo_train.reset()

     def initialize_val_metrics(self):
-        """Initialize train related metrics."""
-        self.elbo_val = ElboMetric(self.n_obs_validation, mode="validation")
+        """Initialize val related metrics."""
+        (
+            self.elbo_val,
+            self.rec_loss_val,
+            self.kl_local_val,
+            self.kl_global_val,
+            self.val_metrics,
+        ) = self._create_elbo_metric_components(
+            mode="validation", n_total=self.n_obs_validation
+        )
         self.elbo_val.reset()

     @property
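
The `elbo = rec_loss + kl_local + (1 / n) * kl_global` line in the hunk above leans on torchmetrics' metric arithmetic: operators applied to Metric objects return a CompositionalMetric that keeps references to its operands and combines their compute() results. A standalone sketch of that behavior (plain torchmetrics, independent of scvi-tools, assuming a version of this era with arithmetic support):

import torch
from torchmetrics import MeanMetric

a = MeanMetric()
b = MeanMetric()
combined = a + 0.5 * b  # builds a CompositionalMetric over `a` and `b`

a.update(torch.tensor(2.0))
b.update(torch.tensor(4.0))
# combined.compute() is a.compute() + 0.5 * b.compute() = 2.0 + 2.0
print(combined.compute())  # tensor(4.)
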
@@ -190,7 +223,8 @@ def forward(self, *args, **kwargs):
     def compute_and_log_metrics(
         self,
         loss_recorder: LossRecorder,
-        elbo_metric: ElboMetric,
+        metrics: MetricCollection,
+        mode: str,
     ):
         """
         Computes and logs metrics.
@@ -201,6 +235,9 @@
             LossRecorder object from scvi-tools module
         metric_attr_name
             The name of the torch metric object to use
+        mode
+            Postfix string to add to the metric name of
+            extra metrics
         """
         rec_loss = loss_recorder.reconstruction_loss
         n_obs_minibatch = rec_loss.shape[0]
@@ -209,44 +246,15 @@
         kl_global = loss_recorder.kl_global

         # use the torchmetric object for the ELBO
-        elbo_metric(
-            rec_loss,
-            kl_local,
-            kl_global,
-            n_obs_minibatch,
+        metrics.update(
+            reconstruction_loss=rec_loss,
+            kl_local=kl_local,
+            kl_global=kl_global,
+            n_obs_minibatch=n_obs_minibatch,
         )
-        # e.g., train or val mode
-        mode = elbo_metric.mode
         # pytorch lightning handles everything with the torchmetric object
-        self.log(
-            f"elbo_{mode}",
-            elbo_metric,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-
-        # log elbo components
-        self.log(
-            f"reconstruction_loss_{mode}",
-            rec_loss / elbo_metric.n_obs_total,
-            reduce_fx=torch.sum,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-        self.log(
-            f"kl_local_{mode}",
-            kl_local / elbo_metric.n_obs_total,
-            reduce_fx=torch.sum,
-            on_step=False,
-            on_epoch=True,
-            batch_size=n_obs_minibatch,
-        )
-        # default aggregation is mean
-        self.log(
-            f"kl_global_{mode}",
-            kl_global,
+        self.log_dict(
+            metrics,
             on_step=False,
             on_epoch=True,
             batch_size=n_obs_minibatch,
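
On the Lightning side, self.log_dict accepts a MetricCollection directly: each metric is logged under its dict key, and with on_step=False, on_epoch=True Lightning calls compute() and reset() at epoch boundaries. A minimal sketch of that contract (illustrative module, not scvi-tools code; assumes pytorch_lightning and torchmetrics versions of this era):

import pytorch_lightning as pl
import torch
from torchmetrics import MeanMetric, MetricCollection

class TinyPlan(pl.LightningModule):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module
        self.train_metrics = MetricCollection({"loss_train": MeanMetric()})

    def training_step(self, batch, batch_idx):
        loss = self.module(batch).mean()  # hypothetical loss computation
        self.train_metrics.update(loss)
        # Logged under the dict keys; epoch aggregation handled by Lightning.
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.module.parameters())
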
@@ -272,7 +280,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
         self.loss_kwargs.update({"kl_weight": self.kl_weight})
         _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         self.log("train_loss", scvi_loss.loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_train)
+        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
         return scvi_loss.loss

     def validation_step(self, batch, batch_idx):
@@ -281,7 +289,7 @@ def validation_step(self, batch, batch_idx):
         # of training examples
         _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
         self.log("validation_loss", scvi_loss.loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_val)
+        self.compute_and_log_metrics(scvi_loss, self.val_metrics, "validation")

     def configure_optimizers(self):
         params = filter(lambda p: p.requires_grad, self.module.parameters())
@@ -452,7 +460,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
             loss += fool_loss * kappa

         self.log("train_loss", loss, on_epoch=True)
-        self.compute_and_log_metrics(scvi_loss, self.elbo_train)
+        self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
         return loss

         # train adversarial classifier
@@ -602,7 +610,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=0):
             on_epoch=True,
             batch_size=len(scvi_losses.reconstruction_loss),
         )
-        self.compute_and_log_metrics(scvi_losses, self.elbo_train)
+        self.compute_and_log_metrics(scvi_losses, self.train_metrics, "train")
         return loss

     def validation_step(self, batch, batch_idx, optimizer_idx=0):
@@ -627,7 +635,7 @@ def validation_step(self, batch, batch_idx, optimizer_idx=0):
             on_epoch=True,
             batch_size=len(scvi_losses.reconstruction_loss),
         )
-        self.compute_and_log_metrics(scvi_losses, self.elbo_val)
+        self.compute_and_log_metrics(scvi_losses, self.val_metrics, "validation")


 class PyroTrainingPlan(pl.LightningModule):
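
Putting the two files together, one minibatch now flows through the metrics roughly like this. A hedged sketch with made-up numbers: it calls the private helper directly and assumes scvi-tools at this commit, so treat it as an illustration rather than public API.

import torch
from scvi.train import TrainingPlan

elbo, rec, kl_loc, kl_glob, metrics = TrainingPlan._create_elbo_metric_components(
    mode="train", n_total=1000
)
metrics.update(
    reconstruction_loss=torch.tensor(320.0),  # summed over the minibatch
    kl_local=torch.tensor(64.0),
    kl_global=torch.tensor(5.0),
    n_obs_minibatch=128,
)
# compute() returns one value per key: elbo_train, reconstruction_loss_train,
# kl_local_train, kl_global_train; in the training plan, self.log_dict(metrics, ...)
# hands this same collection to Lightning for epoch-level aggregation.
print(metrics.compute())
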
