From bd8e5569d7018f1fa535f35e1e24b4ceb92a4a6f Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Mon, 2 Oct 2023 13:23:39 +0200
Subject: [PATCH] docs: fix references to the `*_epoch_end` (#2118)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
---
 docs/source/pages/lightning.rst              | 11 ++++++-----
 docs/source/pages/overview.rst               |  2 +-
 tests/integrations/lightning/boring_model.py |  2 +-
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst
index 131251501d1..655c72103b0 100644
--- a/docs/source/pages/lightning.rst
+++ b/docs/source/pages/lightning.rst
@@ -187,10 +187,11 @@ The following contains a list of pitfalls to be aware of:
         self.val_acc[dataloader_idx](preds, y)
         self.log('val_acc', self.val_acc[dataloader_idx])
 
-* Mixing the two logging methods by calling ``self.log("val", self.metric)`` in ``{training}/{val}/{test}_step`` method and
-  then calling ``self.log("val", self.metric.compute())`` in the corresponding ``{training}/{val}/{test}_epoch_end`` method.
-  Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading to
-  errors or nonsense results.
+* Mixing the two logging methods by calling ``self.log("val", self.metric)`` in ``{training|validation|test}_step``
+  method and then calling ``self.log("val", self.metric.compute())`` in the corresponding
+  ``on_{train|validation|test}_epoch_end`` method.
+  Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading
+  to errors or nonsense results.
 
 * Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because
   ``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the
@@ -209,4 +210,4 @@
 * Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in itself
   is not a metric i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from
   :class:`~torch.nn.ModuleList`. Thus, to log the output of this metric one needs to manually log the returned values (not the object) using ``self.log``
-  and for epoch level logging this should be done in the appropriate ``on_***_epoch_end`` method.
+  and for epoch level logging this should be done in the appropriate ``on_{train|validation|test}_epoch_end`` method.
diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index edf1a2d7162..5fff9455f81 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -331,7 +331,7 @@ inside your LightningModule. In most cases we just have to replace ``self.log``
             # ...
             self.valid_metrics.update(logits, y)
 
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             # use log_dict instead of log
             # metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
             output = self.valid_metrics.compute()
diff --git a/tests/integrations/lightning/boring_model.py b/tests/integrations/lightning/boring_model.py
index 4f1409fd7a8..991e27aa273 100644
--- a/tests/integrations/lightning/boring_model.py
+++ b/tests/integrations/lightning/boring_model.py
@@ -68,7 +68,7 @@ def training_step(...):
 
     or:
 
     model = BaseTestModel()
-    model.training_epoch_end = None
+    model.validation_step = None
 
     """
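
For reference, a minimal sketch of the renamed hook in use, assuming the ``lightning.pytorch`` (2.x) and ``torchmetrics`` APIs; the ``LitClassifier`` class, its layer sizes, and the ``num_classes`` default are illustrative and not taken from this patch:

    # Sketch only: demonstrates the hook rename this patch documents.
    # Assumes lightning>=2.0 and torchmetrics; class and sizes are hypothetical.
    import torch
    from lightning.pytorch import LightningModule
    from torchmetrics.classification import MulticlassAccuracy


    class LitClassifier(LightningModule):
        def __init__(self, num_classes: int = 10) -> None:
            super().__init__()
            self.model = torch.nn.Linear(32, num_classes)
            self.val_acc = MulticlassAccuracy(num_classes=num_classes)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self.model(x)
            self.val_acc.update(logits, y)
            # Log the metric *object*: Lightning calls compute() and reset()
            # at epoch end, so do not also log self.val_acc.compute() yourself.
            self.log("val_acc", self.val_acc)

        def on_validation_epoch_end(self) -> None:
            # Renamed hook (no `outputs` argument, unlike the removed
            # `validation_epoch_end`); use it for any manual epoch-level logging.
            pass

Logging the object in ``validation_step`` is the pattern the patched docs recommend; computing and logging values manually in ``on_{train|validation|test}_epoch_end`` is only needed for wrappers such as ``MetricTracker`` that are not ``Metric`` subclasses.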