
Commit

Merge branch 'master' into curve_average
Borda authored Oct 2, 2023
2 parents b100ef4 + bd8e556 commit 9a86e05
Showing 3 changed files with 8 additions and 7 deletions.
11 changes: 6 additions & 5 deletions docs/source/pages/lightning.rst
@@ -187,10 +187,11 @@ The following contains a list of pitfalls to be aware of:
self.val_acc[dataloader_idx](preds, y)
self.log('val_acc', self.val_acc[dataloader_idx])

-* Mixing the two logging methods by calling ``self.log("val", self.metric)`` in ``{training}/{val}/{test}_step`` method and
-  then calling ``self.log("val", self.metric.compute())`` in the corresponding ``{training}/{val}/{test}_epoch_end`` method.
-  Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading to
-  errors or nonsense results.
+* Mixing the two logging methods by calling ``self.log("val", self.metric)`` in ``{training|validation|test}_step``
+  method and then calling ``self.log("val", self.metric.compute())`` in the corresponding
+  ``on_{train|validation|test}_epoch_end`` method.
+  Because the object is logged in the first case, Lightning will reset the metric before calling the second line leading
+  to errors or nonsense results.

* Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because
``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the
@@ -209,4 +210,4 @@ The following contains a list of pitfalls to be aware of:
* Using :class:`~torchmetrics.wrappers.MetricTracker` wrapper with Lightning is a special case, because the wrapper in itself is not a metric
i.e. it does not inherit from the base :class:`~torchmetrics.Metric` class but instead from :class:`~torch.nn.ModuleList`. Thus,
to log the output of this metric one needs to manually log the returned values (not the object) using ``self.log``
-and for epoch level logging this should be done in the appropriate ``on_***_epoch_end`` method.
+and for epoch level logging this should be done in the appropriate ``on_{train|validation|test}_epoch_end`` method.
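The reset pitfall documented in the hunk above boils down to the metric lifecycle: ``update``/``forward`` accumulate state, ``compute`` reads it out, and Lightning resets the object once the object itself has been logged. A minimal sketch of that lifecycle, using a hypothetical pure-Python ``ToyMeanMetric`` stand-in (not a real torchmetrics ``Metric``; no torch required), shows why a later ``compute()`` call returns a nonsense value:

```python
class ToyMeanMetric:
    """Hypothetical stand-in for a torchmetrics-style metric lifecycle."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # accumulate state across batches
        self.total += value
        self.count += 1

    def compute(self):
        # read out the accumulated state; 0.0 if nothing was accumulated
        return self.total / self.count if self.count else 0.0

    def reset(self):
        self.total = 0.0
        self.count = 0


metric = ToyMeanMetric()
for batch_value in (1.0, 2.0, 3.0):
    metric.update(batch_value)

epoch_value = metric.compute()  # 2.0 -- the correct epoch-level average
metric.reset()                  # what Lightning does after the *object* is logged
after_reset = metric.compute()  # 0.0 -- the "nonsense result" the docs warn about
```

Logging either the object or the computed value is fine on its own; the error comes from doing both, because the reset happens in between.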
2 changes: 1 addition & 1 deletion docs/source/pages/overview.rst
@@ -331,7 +331,7 @@ inside your LightningModule. In most cases we just have to replace ``self.log``
# ...
self.valid_metrics.update(logits, y)

-def validation_epoch_end(self, outputs):
+def on_validation_epoch_end(self):
# use log_dict instead of log
# metrics are logged with keys: val_Accuracy, val_Precision and val_Recall
output = self.valid_metrics.compute()
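As a side note on the snippet above: ``compute()`` on a prefixed collection returns one entry per metric with the prefix prepended to each name, which is what produces the keys ``val_Accuracy``, ``val_Precision`` and ``val_Recall``. A rough sketch of that key handling (plain Python; ``collection_compute`` is a hypothetical helper for illustration, not a torchmetrics API):

```python
def collection_compute(results, prefix="val_"):
    """Prepend a prefix to every metric name, as a prefixed collection does."""
    return {prefix + name: value for name, value in results.items()}


# hypothetical per-metric results for one validation epoch
output = collection_compute({"Accuracy": 0.91, "Precision": 0.88, "Recall": 0.86})
# output keys: val_Accuracy, val_Precision, val_Recall
# a dict in this shape is what self.log_dict(output) would receive
```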
2 changes: 1 addition & 1 deletion tests/integrations/lightning/boring_model.py
@@ -68,7 +68,7 @@ def training_step(...):
or:
model = BaseTestModel()
-model.training_epoch_end = None
+model.validation_step = None
"""

