
Commit

Merge branch 'master' into feature/fid64
Borda authored Mar 21, 2023
2 parents 813b980 + b613c50 commit 3e5f4ea
Showing 31 changed files with 1,505 additions and 181 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[#1621](https://github.com/Lightning-AI/metrics/pull/1621),
[#1624](https://github.com/Lightning-AI/metrics/pull/1624),
[#1623](https://github.com/Lightning-AI/metrics/pull/1623),
[#1638](https://github.com/Lightning-AI/metrics/pull/1638),
[#1631](https://github.com/Lightning-AI/metrics/pull/1631),
)


@@ -73,6 +75,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))


- Added `ModifiedPanopticQuality` metric to detection package ([#1627](https://github.com/Lightning-AI/metrics/pull/1627))


### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
23 changes: 23 additions & 0 deletions docs/source/detection/modified_panoptic_quality.rst
@@ -0,0 +1,23 @@
.. customcarditem::
:header: Modified Panoptic Quality
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
:tags: Detection

#########################
Modified Panoptic Quality
#########################

.. include:: ../links.rst

Module Interface
________________

.. autoclass:: torchmetrics.ModifiedPanopticQuality
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.modified_panoptic_quality
:noindex:
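
A quick usage sketch (not part of this page; the values below are illustrative). It assumes the ``things``/``stuffs``
arguments follow the same convention as :class:`~torchmetrics.PanopticQuality`, with ``preds`` and ``target`` carrying
``(category_id, instance_id)`` pairs in the last dimension.

.. code-block:: python

    import torch
    from torchmetrics import ModifiedPanopticQuality

    # illustrative data: each entry is (category_id, instance_id)
    preds = torch.tensor([[0, 0], [0, 1], [6, 0], [7, 0], [0, 2], [1, 0]])
    target = torch.tensor([[0, 1], [0, 0], [6, 0], [7, 0], [6, 0], [255, 0]])

    # categories 0 and 1 are counted as "things", 6 and 7 as "stuffs"
    metric = ModifiedPanopticQuality(things={0, 1}, stuffs={6, 7})
    score = metric(preds, target)  # scalar tensor with the modified PQ value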
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -134,3 +134,4 @@
.. _Minkowski Distance: https://en.wikipedia.org/wiki/Minkowski_distance
.. _Demographic parity: http://www.fairmlbook.org/
.. _Equal opportunity: https://proceedings.neurips.cc/paper/2016/hash/9d2682367c3935defcb1f9e247a97c0d-Abstract.html
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
66 changes: 25 additions & 41 deletions docs/source/pages/lightning.rst
@@ -45,29 +45,31 @@ The example below shows how to use a metric in your `LightningModule <https://py
self.log('train_acc_step', self.accuracy)
...

def training_epoch_end(self, outs):
def on_train_epoch_end(self):
# log epoch metric
self.log('train_acc_epoch', self.accuracy)

Metric logging in Lightning happens through the ``self.log`` or ``self.log_dict`` method. Both methods only support the
logging of *scalar tensors*. While the vast majority of metrics in torchmetrics return a scalar tensor, some metrics,
such as :class:`~torchmetrics.ConfusionMatrix`, :class:`~torchmetrics.ROC`, :class:`~torchmetrics.MeanAveragePrecision`,
and :class:`~torchmetrics.ROUGEScore`, return outputs that are non-scalar tensors (often dicts or lists of tensors) and
should therefore be dealt with separately. For info about the return type and shape, please look at the documentation
for the ``compute`` method of each metric you want to log.
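
For example, here is a minimal sketch (assuming a 3-class classification setup; names are illustrative) of handling a
non-scalar metric by reducing its output to scalars before logging:

.. code-block:: python

    import torchmetrics
    from pytorch_lightning import LightningModule

    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            # compute() returns a [3, 3] tensor, which self.log cannot handle directly
            self.confmat = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=3)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            preds = self(x)
            self.confmat.update(preds, y)

        def on_validation_epoch_end(self):
            confmat = self.confmat.compute()
            # log each entry as a separate scalar
            for i in range(confmat.shape[0]):
                for j in range(confmat.shape[1]):
                    self.log(f"confmat_{i}_{j}", confmat[i, j].float())
            self.confmat.reset()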

********************
Logging TorchMetrics
********************

Logging metrics can be done in two ways: either logging the metric object directly or the computed metric values.
When :class:`~torchmetrics.Metric` objects, which return a scalar tensor, are logged directly in Lightning using the
LightningModule `self.log <https://pytorch-lightning.readthedocs.io/en/stable/extensions/logging.html#logging-from-a-lightningmodule>`_
method, Lightning will log the metric based on the ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``. If
``on_epoch`` is ``True``, the logger automatically logs the end-of-epoch metric value by calling ``.compute()``.

.. note::

``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx``
flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class
contains its own distributed synchronization logic.
``sync_dist``, ``sync_dist_group`` and ``reduce_fx`` flags from ``self.log(...)`` don't affect the metric logging
in any manner. The metric class contains its own distributed synchronization logic.

This however is only true for metrics that inherit the base class ``Metric``,
and thus the functional metric API provides no support for in-built distributed synchronization
@@ -96,8 +98,8 @@ value by calling ``.compute()``.
self.valid_acc(logits, y)
self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

As an alternative to logging the metric object and letting Lightning take care of when to reset the metric etc., you
can also manually log the output of the metrics.

.. testcode:: python

@@ -115,27 +117,28 @@ of the metrics.
batch_value = self.train_acc(preds, y)
self.log('train_acc_step', batch_value)

def training_epoch_end(self, outputs):
def on_train_epoch_end(self):
self.train_acc.reset()

def validation_step(self, batch, batch_idx):
logits = self(x)
...
self.valid_acc.update(logits, y)

def validation_epoch_end(self, outputs):
def on_validation_epoch_end(self):
self.log('valid_acc_epoch', self.valid_acc.compute())
self.valid_acc.reset()

Note that logging metrics this way requires you to manually reset the metrics at the end of the epoch. In general, we
recommend logging the metric object to make sure that metrics are correctly computed and reset. Additionally, we
highly recommend that the two ways of logging are not mixed, as it can lead to wrong results.

.. note::

When using any Modular metric, calling ``self.metric(...)`` or ``self.metric.forward(...)`` serves the dual purpose
of calling ``self.metric.update()`` on its input and simultaneously returning the metric value over the provided
input. So if you are logging a metric *only* at the epoch level (as in the example above), it is recommended to call
``self.metric.update()`` directly to avoid the extra computation.
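
A minimal sketch of that pattern (assuming a simple classification setup; names are illustrative):

.. code-block:: python

    import torchmetrics
    from pytorch_lightning import LightningModule

    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.valid_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            # update() only accumulates state, avoiding the extra forward computation
            self.valid_acc.update(logits, y)
            # log the metric object; Lightning calls compute() at the end of the epoch
            self.log('valid_acc', self.valid_acc, on_step=False, on_epoch=True)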

.. testcode:: python

@@ -158,25 +161,6 @@ Common Pitfalls

The following contains a list of pitfalls to be aware of:

* If using metrics in data parallel mode (dp), the metric update/logging should be done
in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
or ``test``). This is because ``dp`` splits the batches during the forward pass, and metric states are destroyed after each forward pass, thus leading to wrong accumulation. In practice, do the following:

.. testcode:: python

class MyModule(LightningModule):

def training_step(self, batch, batch_idx):
data, target = batch
preds = self(data)
# ...
return {'loss': loss, 'preds': preds, 'target': target}

def training_step_end(self, outputs):
# update and log
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)

* Modular metrics contain internal states that should belong to only one DataLoader. In case you are using multiple DataLoaders,
it is recommended to initialize a separate modular metric instance for each DataLoader and use them separately (see the
sketch below). The same holds for using separate metrics for training, validation and testing.
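
A minimal sketch of the per-DataLoader pattern (assuming two validation DataLoaders; names are illustrative):

.. code-block:: python

    import torch
    import torchmetrics
    from pytorch_lightning import LightningModule

    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            # one metric instance per validation DataLoader, wrapped in a ModuleList
            # so the instances move to the correct device together with the model
            self.val_accs = torch.nn.ModuleList(
                [torchmetrics.Accuracy(task="multiclass", num_classes=10) for _ in range(2)]
            )

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            x, y = batch
            logits = self(x)
            self.val_accs[dataloader_idx].update(logits, y)
            self.log(f"valid_acc_dl{dataloader_idx}", self.val_accs[dataloader_idx])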
26 changes: 0 additions & 26 deletions docs/source/pages/overview.rst
@@ -130,32 +130,6 @@ the native `MetricCollection`_ module can also be used to wrap multiple metrics.

You can always check which device the metric is located on using the `.device` property.

Metrics in Dataparallel (DP) mode
=================================

When using metrics in `Dataparallel (DP) <https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html#torch.nn.DataParallel>`_
mode, one should be aware that DP will both create and clean up replicas of Metric objects during a single forward pass.
This has the consequence that the metric state of the replicas will by default be destroyed before we can sync
them. It is therefore recommended, when using metrics in DP mode, to initialize them with ``dist_sync_on_step=True``
so that metric states are synchronized between the main process and the replicas before they are destroyed.

Additionally, if metrics are used together with a `LightningModule`, the metric update/logging should be done
in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation`` or ``test``), else
it will lead to wrong accumulation. In practice, do the following:

.. testcode::

def training_step(self, batch, batch_idx):
data, target = batch
preds = self(data)
...
return {'loss': loss, 'preds': preds, 'target': target}

def training_step_end(self, outputs):
#update and log
self.metric(outputs['preds'], outputs['target'])
self.log('metric', self.metric)

Metrics in Distributed Data Parallel (DDP) mode
===============================================

2 changes: 1 addition & 1 deletion requirements/text_test.txt
@@ -1,7 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

jiwer >=2.3.0, <=2.5.2
jiwer >=2.3.0, <=3.0.0
rouge-score >0.1.0, <=0.1.2
bert_score ==0.3.13
transformers >4.4.0, <4.26.2
2 changes: 1 addition & 1 deletion requirements/typing.txt
@@ -1,4 +1,4 @@
mypy==1.0.1
mypy==1.1.1

types-PyYAML
types-emoji
3 changes: 2 additions & 1 deletion src/torchmetrics/__init__.py
@@ -43,7 +43,7 @@
StatScores,
)
from torchmetrics.collections import MetricCollection # noqa: E402
from torchmetrics.detection import PanopticQuality # noqa: E402
from torchmetrics.detection import ModifiedPanopticQuality, PanopticQuality # noqa: E402
from torchmetrics.image import ( # noqa: E402
ErrorRelativeGlobalDimensionlessSynthesis,
MultiScaleStructuralSimilarityIndexMeasure,
@@ -152,6 +152,7 @@
"MetricTracker",
"MinMaxMetric",
"MinMetric",
"ModifiedPanopticQuality",
"MultioutputWrapper",
"MultiScaleStructuralSimilarityIndexMeasure",
"PanopticQuality",
48 changes: 47 additions & 1 deletion src/torchmetrics/classification/dice.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Tuple, no_type_check
from typing import Any, Callable, Optional, Sequence, Tuple, Union, no_type_check

import torch
from torch import Tensor
@@ -21,6 +21,11 @@
from torchmetrics.functional.classification.stat_scores import _stat_scores_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["Dice.plot"]


class Dice(Metric):
@@ -235,3 +240,44 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, _, fn = self._get_final_stats()
return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division)

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure object and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> metric.update(randint(2,(10,)), randint(2,(10,)))
>>> fig_, ax_ = metric.plot()
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> from torch import randint
>>> from torchmetrics.classification import Dice
>>> metric = Dice()
>>> values = [ ]
>>> for _ in range(10):
... values.append(metric(randint(2,(10,)), randint(2,(10,))))
>>> fig_, ax_ = metric.plot(values)
"""
return self._plot(val, ax)