-
Notifications
You must be signed in to change notification settings - Fork 413
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Better support for classwise logging (#832)
* implementation * collection * tests * docs Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
- Loading branch information
1 parent
9daa5e2
commit 6131d82
Showing
9 changed files
with
182 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pytest | ||
import torch | ||
|
||
from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall | ||
|
||
|
||
def test_raises_error_on_wrong_input(): | ||
"""Test that errors are raised on wrong input.""" | ||
with pytest.raises(ValueError, match="Expected argument `metric` to be an instance of `torchmetrics.Metric` but.*"): | ||
ClasswiseWrapper([]) | ||
|
||
with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"): | ||
ClasswiseWrapper(Accuracy(), "hest") | ||
|
||
|
||
def test_output_no_labels(): | ||
"""Test that wrapper works with no label input.""" | ||
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) | ||
preds = torch.randn(10, 3).softmax(dim=-1) | ||
target = torch.randint(3, (10,)) | ||
val = metric(preds, target) | ||
assert isinstance(val, dict) | ||
assert len(val) == 3 | ||
for i in range(3): | ||
assert f"accuracy_{i}" in val | ||
|
||
|
||
def test_output_with_labels(): | ||
"""Test that wrapper works with label input.""" | ||
labels = ["horse", "fish", "cat"] | ||
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels) | ||
preds = torch.randn(10, 3).softmax(dim=-1) | ||
target = torch.randint(3, (10,)) | ||
val = metric(preds, target) | ||
assert isinstance(val, dict) | ||
assert len(val) == 3 | ||
for lab in labels: | ||
assert f"accuracy_{lab}" in val | ||
|
||
|
||
def test_using_metriccollection(): | ||
"""Test wrapper in combination with metric collection.""" | ||
labels = ["horse", "fish", "cat"] | ||
metric = MetricCollection( | ||
{ | ||
"accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels), | ||
"recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels), | ||
} | ||
) | ||
preds = torch.randn(10, 3).softmax(dim=-1) | ||
target = torch.randint(3, (10,)) | ||
val = metric(preds, target) | ||
assert isinstance(val, dict) | ||
assert len(val) == 6 | ||
for lab in labels: | ||
assert f"accuracy_{lab}" in val | ||
assert f"recall_{lab}" in val |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from typing import Dict, List, Optional, Union | ||
|
||
from torch import Tensor | ||
|
||
from torchmetrics import Metric | ||
|
||
|
||
class ClasswiseWrapper(Metric): | ||
"""Wrapper class for altering the output of classification metrics that returns multiple values to include | ||
label information. | ||
Args: | ||
metric: base metric that should be wrapped. It is assumed that the metric outputs a single | ||
tensor that is split along the first dimension. | ||
class_labels: list of strings indicating the different classes. | ||
Example: | ||
>>> import torch | ||
>>> _ = torch.manual_seed(42) | ||
>>> from torchmetrics import Accuracy, ClasswiseWrapper | ||
>>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) | ||
>>> preds = torch.randn(10, 3).softmax(dim=-1) | ||
>>> target = torch.randint(3, (10,)) | ||
>>> metric(preds, target) | ||
{'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)} | ||
Example (labels as list of strings): | ||
>>> import torch | ||
>>> from torchmetrics import Accuracy, ClasswiseWrapper | ||
>>> metric = ClasswiseWrapper( | ||
... Accuracy(num_classes=3, average=None), | ||
... labels=["horse", "fish", "dog"] | ||
... ) | ||
>>> preds = torch.randn(10, 3).softmax(dim=-1) | ||
>>> target = torch.randint(3, (10,)) | ||
>>> metric(preds, target) | ||
{'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)} | ||
Example (in metric collection): | ||
>>> import torch | ||
>>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall | ||
>>> labels = ["horse", "fish", "dog"] | ||
>>> metric = MetricCollection( | ||
... {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels), | ||
... 'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)} | ||
... ) | ||
>>> preds = torch.randn(10, 3).softmax(dim=-1) | ||
>>> target = torch.randint(3, (10,)) | ||
>>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE | ||
{'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000), | ||
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)} | ||
""" | ||
|
||
def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: | ||
super().__init__() | ||
if not isinstance(metric, Metric): | ||
raise ValueError(f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {metric}") | ||
if labels is not None and not (isinstance(labels, list) and all(isinstance(lab, str) for lab in labels)): | ||
raise ValueError(f"Expected argument `labels` to either be `None` or a list of strings but got {labels}") | ||
self.metric = metric | ||
self.labels = labels | ||
|
||
def _convert(self, x: Tensor) -> Dict[Union[str, int], float]: | ||
name = self.metric.__class__.__name__.lower() | ||
if self.labels is None: | ||
return {f"{name}_{i}": val for i, val in enumerate(x)} | ||
return {f"{name}_{lab}": val for lab, val in zip(self.labels, x)} | ||
|
||
def update(self, *args, **kwargs) -> None: | ||
self.metric.update(*args, **kwargs) | ||
|
||
def compute(self) -> Dict[str, Tensor]: | ||
return self._convert(self.metric.compute()) |