Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plot method for aggregation + refactor tests #1485

Merged
merged 49 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
512bb82
starting point
SkafteNicki Feb 5, 2023
81f6137
fix aggregation methods
SkafteNicki Feb 5, 2023
3a5478e
update testing
SkafteNicki Feb 5, 2023
ce8fa90
fix confusion matrix
SkafteNicki Feb 6, 2023
99b7cd5
Merge branch 'master' into plot/aggregation
SkafteNicki Feb 6, 2023
4f5dcf1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
1cf220b
changelog
SkafteNicki Feb 6, 2023
f3d19cb
Merge branch 'plot/aggregation' of https://github.com/PyTorchLightnin…
SkafteNicki Feb 6, 2023
84ac12e
Apply suggestions from code review
Borda Feb 7, 2023
cec45f6
Merge branch 'master' into plot/aggregation
Borda Feb 7, 2023
22952c7
fix PT
Borda Feb 7, 2023
814c206
Merge branch 'master' into plot/aggregation
SkafteNicki Feb 10, 2023
b973c85
fix doctest
SkafteNicki Feb 10, 2023
70bd493
merge
SkafteNicki Feb 18, 2023
b2b9876
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
29cd999
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
d671db9
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
9a05114
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
42340d8
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
7e3ab80
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
6ca60be
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
491f00d
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
7b61fe7
skip
Borda Feb 20, 2023
f2a48bf
inputs
Borda Feb 20, 2023
02506bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2023
c8caa41
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 20, 2023
377f9f0
type
Borda Feb 20, 2023
35ac117
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
f5e659b
cleaning
Borda Feb 21, 2023
60c98df
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
168a2dc
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
d95aab9
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
f95ad8a
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
e6d281c
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 21, 2023
c0c4e0b
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
ff48709
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
454c4a3
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
28b4342
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
b0519c1
Merge branch 'master' into plot/aggregation
Borda Feb 22, 2023
23e7b22
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
15d54b9
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
39b7bfb
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
ad87fb8
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 22, 2023
a039946
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 23, 2023
155e3ea
Merge branch 'master' into plot/aggregation
SkafteNicki Feb 25, 2023
bec2a21
Merge branch 'master' into plot/aggregation
SkafteNicki Feb 25, 2023
c6bbdd9
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 25, 2023
d7a4403
Merge branch 'master' into plot/aggregation
mergify[bot] Feb 26, 2023
d549188
drop unused plot_options
Borda Feb 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))


- Added support for plotting of aggregation metrics through `.plot()` method ([#1485](https://github.com/Lightning-AI/metrics/pull/1485))

### Changed

- Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
Expand Down
180 changes: 179 additions & 1 deletion src/torchmetrics/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, List, Union
from typing import Any, Callable, List, Optional, Sequence, Union

import torch
from torch import Tensor

from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot"]


class BaseAggregator(Metric):
Expand Down Expand Up @@ -153,6 +158,49 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
if value.numel(): # make sure tensor not empty
self.value = torch.max(self.value, torch.max(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics import MaxMetric
>>> metric = MaxMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax


class MinMetric(BaseAggregator):
"""Aggregate a stream of value into their minimum value.
Expand Down Expand Up @@ -214,6 +262,49 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
if value.numel(): # make sure tensor not empty
self.value = torch.min(self.value, torch.min(value))

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics import MinMetric
>>> metric = MinMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric(i))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax


class SumMetric(BaseAggregator):
"""Aggregate a stream of value into their sum.
Expand Down Expand Up @@ -273,6 +364,50 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore
if value.numel():
self.value += value.sum()

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torch import rand, randint
>>> from torchmetrics import SumMetric
>>> metric = SumMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax


class CatMetric(BaseAggregator):
"""Concatenate a stream of values.
Expand Down Expand Up @@ -407,3 +542,46 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
def compute(self) -> Tensor:
"""Compute the aggregated value."""
return self.value / self.weight

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> # Example plotting a single value
>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> metric.update([1, 2, 3])
>>> fig_, ax_ = metric.plot()

.. plot::
:scale: 75

>>> # Example plotting multiple values
>>> from torchmetrics import MeanMetric
>>> metric = MeanMetric()
>>> values = [ ]
>>> for i in range(10):
... values.append(metric([i, i+1]))
>>> fig_, ax_ = metric.plot(values)
"""
val = val or self.compute()
fig, ax = plot_single_or_multi_val(
val, ax=ax, higher_is_better=self.higher_is_better, **self.plot_options, name=self.__class__.__name__
)
return fig, ax
82 changes: 78 additions & 4 deletions src/torchmetrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -40,7 +40,11 @@
from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MulticlassConfusionMatrix.plot"]
__doctest_skip__ = [
"BinaryConfusionMatrix.plot",
"MulticlassConfusionMatrix.plot",
"MulticlassConfusionMatrix.plot",
]


class BinaryConfusionMatrix(Metric):
Expand Down Expand Up @@ -126,6 +130,39 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _binary_confusion_matrix_compute(self.confmat, self.normalize)

def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


class MulticlassConfusionMatrix(Metric):
r"""Compute the `confusion matrix`_ for multiclass tasks.
Expand Down Expand Up @@ -231,12 +268,16 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _multiclass_confusion_matrix_compute(self.confmat, self.normalize)

def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes

Returns:
fig: Figure object
Expand All @@ -258,7 +299,7 @@ def plot(self, val: Optional[Tensor] = None) -> _PLOT_OUT_TYPE:
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val)
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


Expand Down Expand Up @@ -352,6 +393,39 @@ def compute(self) -> Tensor:
"""Compute confusion matrix."""
return _multilabel_confusion_matrix_compute(self.confmat, self.normalize)

def plot(
self, val: Optional[Tensor] = None, add_text: bool = True, labels: Optional[List[str]] = None
) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.

Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
add_text: if the value of each cell should be added to the plot
labels: a list of strings, if provided will be added to the plot to indicate the different classes

Returns:
Figure and Axes object

Raises:
ModuleNotFoundError:
If `matplotlib` is not installed

.. plot::
:scale: 75

>>> from torch import randint
>>> from torchmetrics.classification import MulticlassConfusionMatrix
>>> metric = MulticlassConfusionMatrix(num_classes=5)
>>> metric.update(randint(5, (20,)), randint(5, (20,)))
>>> fig_, ax_ = metric.plot()
"""
val = val or self.compute()
if not isinstance(val, Tensor):
raise TypeError(f"Expected val to be a single tensor but got {val}")
fig, ax = plot_confusion_matrix(val, add_text=add_text, labels=labels)
return fig, ax


class ConfusionMatrix:
r"""Compute the `confusion matrix`_.
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class Metric(Module, ABC):
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
full_state_update: Optional[bool] = None
plot_options: Dict[str, Union[str, float]] = {}

def __init__(
self,
Expand Down
8 changes: 6 additions & 2 deletions src/torchmetrics/utilities/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,18 @@ def plot_confusion_matrix(
"Expected number of elements in arg `labels` to match number of labels in confmat but "
f"got {len(labels)} and {n_classes}"
)
labels: Union[List[int], List[str]] = labels if labels is not None else np.arange(n_classes).tolist()
if confmat.ndim == 3:
fig_label = labels if confmat.ndim == 3 and labels is not None else np.arange(nb)
labels = np.arange(n_classes).tolist()
else:
labels: Union[List[int], List[str]] = labels if labels is not None else np.arange(n_classes).tolist()

fig, axs = plt.subplots(nrows=rows, ncols=cols)
axs = trim_axs(axs, nb)
for i in range(nb):
if rows != 1 and cols != 1:
ax = axs[i]
ax.set_title(f"Label {i}", fontsize=15)
ax.set_title(f"Label {fig_label[i]}", fontsize=15)
else:
ax = axs
ax.imshow(confmat[i].cpu().detach() if confmat.ndim == 3 else confmat.cpu().detach())
Expand Down
Loading