Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _filter_kwargs method to ClasswiseWrapper for better integration with MetricCollection #2575

Merged
merged 9 commits into from
Jun 5, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `BootstrapWrapper` not being reset correctly ([#2574](https://github.com/Lightning-AI/torchmetrics/pull/2574))


- Fixed integration between `ClasswiseWrapper` and `MetricCollection` with custom `_filter_kwargs` method ([#2575](https://github.com/Lightning-AI/torchmetrics/pull/2575))


## [1.4.0] - 2024-05-03

### Added
Expand Down
11 changes: 8 additions & 3 deletions src/torchmetrics/wrappers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ def __init__(

self._update_count = 1

def _convert(self, x: Tensor) -> Dict[str, Any]:
def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
"""Filter kwargs for the metric."""
return self.metric._filter_kwargs(**kwargs)

def _convert_output(self, x: Tensor) -> Dict[str, Any]:
"""Convert output to dictionary with labels as keys."""
# Will set the class name as prefix if neither prefix nor postfix is given
if not self._prefix and not self._postfix:
prefix = f"{self.metric.__class__.__name__.lower()}_"
Expand All @@ -156,15 +161,15 @@ def _convert(self, x: Tensor) -> Dict[str, Any]:

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Calculate on batch and accumulate to global state."""
return self._convert(self.metric(*args, **kwargs))
return self._convert_output(self.metric(*args, **kwargs))

def update(self, *args: Any, **kwargs: Any) -> None:
"""Update state."""
self.metric.update(*args, **kwargs)

def compute(self) -> Dict[str, Tensor]:
"""Compute metric."""
return self._convert(self.metric.compute())
return self._convert_output(self.metric.compute())

def reset(self) -> None:
"""Reset metric."""
Expand Down
25 changes: 25 additions & 0 deletions tests/unittests/wrappers/test_classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall
from torchmetrics.clustering import CalinskiHarabaszScore
from torchmetrics.wrappers import ClasswiseWrapper


Expand Down Expand Up @@ -150,3 +151,27 @@ def test_double_use_of_prefix_with_metriccollection():
assert "val/accuracy" in res
assert "val/f_score_Tree" in res
assert "val/f_score_Bush" in res


def test_filter_kwargs_and_metriccollection():
"""Test that kwargs are correctly filtered when using metric collection."""
metric = MetricCollection(
{
"accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)),
"cluster": CalinskiHarabaszScore(),
},
)
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
data = torch.randn(10, 3)

metric.update(preds=preds, target=target, data=data, labels=target)
metric(preds=preds, target=target, data=data, labels=target)
val = metric.compute()

assert isinstance(val, dict)
assert len(val) == 4
assert "multiclassaccuracy_0" in val
assert "multiclassaccuracy_1" in val
assert "multiclassaccuracy_2" in val
assert "cluster" in val
Loading