Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement MultioutputWrapper #510

Merged
merged 45 commits into from
Sep 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
d9b477f
Implement MultioutputWrapper
an1lam Sep 8, 2021
d6f3f15
Format docstrings properly
an1lam Sep 8, 2021
bf57891
Update docs & exports
an1lam Sep 9, 2021
d98f918
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
373859e
Add wrapper to __all__
an1lam Sep 9, 2021
6597784
Address deepsource flagged issues
an1lam Sep 9, 2021
7f713f0
Address PR comments
an1lam Sep 9, 2021
6219c10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
fa02473
Update docs & make squeeze_outputs actually work
an1lam Sep 9, 2021
5cc7924
Update tests to be randomized and parametrized
an1lam Sep 9, 2021
424275f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 9, 2021
932414a
Fix typechecking error
an1lam Sep 9, 2021
fddbcaf
Merge branch 'master' into master
SkafteNicki Sep 10, 2021
cf9cb67
changelog
SkafteNicki Sep 10, 2021
48c0db9
Fix forward to not reset the states
an1lam Sep 10, 2021
190eafd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
fcf7ce1
Merge branch 'master' into master
an1lam Sep 10, 2021
774458c
Fix DeepSource flagged issues
an1lam Sep 10, 2021
a5d90c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
73d97f9
Fix leftover conflict code
an1lam Sep 10, 2021
f2c3022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 10, 2021
bf3cd72
Remove unused imports
an1lam Sep 10, 2021
ef59963
Fix more deepsource issues
an1lam Sep 10, 2021
a2beefa
Pass distributed args to base_metric_class
an1lam Sep 13, 2021
4ce3a6d
Fix formatting
an1lam Sep 13, 2021
5094e4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
5fba423
Fix typecheck error
an1lam Sep 13, 2021
44add1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
0e089ad
Use global NUM_CLASSES in test
an1lam Sep 13, 2021
7de3a6f
Address my nitpick comments on PR
an1lam Sep 13, 2021
3cabb9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2021
bafd0d1
Merge branch 'master' into master
an1lam Sep 13, 2021
7c89fed
Remove movedim call for compatibility
an1lam Sep 21, 2021
acad18d
Merge branch 'master' into master
SkafteNicki Sep 21, 2021
c5b0a3a
Fix additional isnan incompat error
an1lam Sep 21, 2021
49bdda9
Fix docs formatting
an1lam Sep 21, 2021
ab0beee
Merge branch 'master' into master
an1lam Sep 21, 2021
8b2fa0b
Merge branch 'master' into master
SkafteNicki Sep 21, 2021
b0c0782
Apply suggestions from code review
Borda Sep 22, 2021
0230923
Apply suggestions from code review
Borda Sep 22, 2021
4624470
Apply suggestions from code review
Borda Sep 22, 2021
86fd51c
Merge branch 'master' into master
mergify[bot] Sep 22, 2021
f0f1851
long
Borda Sep 22, 2021
b39c46f
Merge branch 'master' into master
SkafteNicki Sep 24, 2021
ea45a8f
fix device
SkafteNicki Sep 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))


- Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510))


### Changed

- `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
Expand Down
6 changes: 6 additions & 0 deletions docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -578,3 +578,9 @@ MetricTracker

.. autoclass:: torchmetrics.MetricTracker
:noindex:

MultioutputWrapper
~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MultioutputWrapper
:noindex:
142 changes: 142 additions & 0 deletions tests/wrappers/test_multioutput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
from collections import namedtuple
from functools import partial
from typing import Any, Callable, Optional

import pytest
import torch
from sklearn.metrics import accuracy_score
from sklearn.metrics import r2_score as sk_r2score

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
from torchmetrics import Metric
from torchmetrics.classification import Accuracy
from torchmetrics.regression import R2Score
from torchmetrics.wrappers.multioutput import MultioutputWrapper

seed_all(42)


class _MultioutputMetric(Metric):
"""Test class that allows passing base metric as a class rather than its instantiation to the wrapper."""

def __init__(
self,
base_metric_class,
num_outputs: int = 1,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Any = None,
dist_sync_fn: Optional[Callable] = None,
**base_metric_kwargs,
) -> None:
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.metric = MultioutputWrapper(
base_metric_class(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
**base_metric_kwargs,
),
num_outputs=num_outputs,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
dist_sync_fn=dist_sync_fn,
)

def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
"""Update the each pair of outputs and predictions."""
return self.metric.update(preds, target)

def compute(self) -> torch.Tensor:
"""Compute the R2 score between each pair of outputs and predictions."""
return self.metric.compute()

@torch.jit.unused
def forward(self, *args, **kwargs):
"""Run forward on the underlying metric."""
return self.metric(*args, **kwargs)

def reset(self) -> None:
"""Reset the underlying metric state."""
self.metric.reset()


num_targets = 2

Input = namedtuple("Input", ["preds", "target"])

_multi_target_regression_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
_multi_target_classification_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets),
target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, num_targets)),
)


def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values"):
"""Compute R2 score over multiple outputs."""
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput)
if adjusted != 0:
r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1)
return r2_score


def _multi_target_sk_accuracy(preds, target, num_outputs):
"""Compute accuracy over multiple outputs."""
accs = []
for i in range(num_outputs):
accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i]))
return accs


@pytest.mark.parametrize(
"base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs",
[
(
R2Score,
_multi_target_sk_r2score,
_multi_target_regression_inputs.preds,
_multi_target_regression_inputs.target,
num_targets,
{},
),
(
Accuracy,
partial(_multi_target_sk_accuracy, num_outputs=2),
_multi_target_classification_inputs.preds,
_multi_target_classification_inputs.target,
num_targets,
dict(num_classes=NUM_CLASSES),
),
],
)
class TestMultioutputWrapper(MetricTester):
"""Test the MultioutputWrapper class with regression and classification inner metrics."""

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_multioutput_wrapper(
self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step
):
"""Test that the multioutput wrapper properly slices and computes outputs along the output dimension for
both classification and regression metrics."""
self.run_class_metric_test(
ddp,
preds,
target,
_MultioutputMetric,
compare_metric,
dist_sync_on_step,
metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs),
)
3 changes: 2 additions & 1 deletion torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
RetrievalRecall,
)
from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402
from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402

__all__ = [
"functional",
Expand Down Expand Up @@ -103,6 +103,7 @@
"Metric",
"MetricCollection",
"MetricTracker",
"MultioutputWrapper",
"PearsonCorrcoef",
"PIT",
"Precision",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401
an1lam marked this conversation as resolved.
Show resolved Hide resolved
from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401
168 changes: 168 additions & 0 deletions torchmetrics/wrappers/multioutput.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from copy import deepcopy
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import nn

from torchmetrics import Metric
from torchmetrics.utilities import apply_to_collection


def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor:
"""Get indices of rows along dim 0 which have NaN values."""
if len(tensors) == 0:
raise ValueError("Must pass at least one tensor as argument")
sentinel = tensors[0]
nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
for tensor in tensors:
permuted_tensor = tensor.flatten(start_dim=1)
nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)
return nan_idxs


class MultioutputWrapper(Metric):
"""Wrap a base metric to enable it to support multiple outputs.

Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for
multioutput mode. This class wraps such metrics to support computing one metric per output.
Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs.
This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension
(2, ...) where ... represents the dimensions the metric returns when not wrapped.

In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will remove the
intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses
`MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally
has missing labels for classes like ``R2Score`` is that this class supports removing NaN values
(parameter ``remove_nans``) on a per-output basis. When ``remove_nans`` is passed the wrapper will remove all rows

Args:
base_metric:
Metric being wrapped.
num_outputs:
Expected dimensionality of the output dimension. This parameter is
used to determine the number of distinct metrics we need to track.
output_dim:
Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
can have a different number of dimensions than the predictions. This can be worked around if the output
dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
remove_nans:
Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying
metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N
represents the length of the batch or dataset being passed in.
squeeze_outputs:
If true, will squeeze the 1-item dimensions left after `index_select` is applied.
This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful
for certain classification metrics that can't handle additional 1-item dimensions.
compute_on_step:
Whether to recompute the metric value on each update step.
dist_sync_on_step:
Required for distributed training support.
process_group:
Specify the process group on which synchronization is called.
The default: None (which selects the entire world)
dist_sync_fn:
Required for distributed training support.

Example:

>>> # Mimic R2Score in `multioutput`, `raw_values` mode:
>>> import torch
>>> from torchmetrics import MultioutputWrapper, R2Score
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = MultioutputWrapper(R2Score(), 2)
>>> r2score(preds, target)
[tensor(0.9654), tensor(0.9082)]
>>> # Classification metric where prediction and label tensors have different shapes.
>>> from torchmetrics import BinnedAveragePrecision
>>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
>>> preds = torch.tensor([
... [[.1, .8], [.8, .05], [.1, .15]],
... [[.1, .1], [.2, .3], [.7, .6]],
... [[.002, .4], [.95, .45], [.048, .15]]
... ])
>>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
>>> binned_avg_precision(preds, target)
[[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]
"""
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
base_metric: Metric,
num_outputs: int,
output_dim: int = -1,
remove_nans: bool = True,
squeeze_outputs: bool = True,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)
self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
self.output_dim = output_dim
self.remove_nans = remove_nans
self.squeeze_outputs = squeeze_outputs

def _get_args_kwargs_by_output(
self, *args: torch.Tensor, **kwargs: torch.Tensor
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
"""Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out."""
args_kwargs_by_output = []
for i in range(len(self.metrics)):
selected_args = apply_to_collection(
args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
)
selected_kwargs = apply_to_collection(
kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
)
if self.remove_nans:
args_kwargs = selected_args + tuple(selected_kwargs.values())
nan_idxs = _get_nan_indices(*args_kwargs)
selected_args = [arg[~nan_idxs] for arg in selected_args]
selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()}

if self.squeeze_outputs:
selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
args_kwargs_by_output.append((selected_args, selected_kwargs))
return args_kwargs_by_output

def update(self, *args: Any, **kwargs: Any) -> None:
"""Update each underlying metric with the corresponding output."""
reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
metric.update(*selected_args, **selected_kwargs)

def compute(self) -> List[torch.Tensor]:
"""Compute metrics."""
return [m.compute() for m in self.metrics]

@torch.jit.unused
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Call underlying forward methods and aggregate the results if they're non-null.

We override this method to ensure that state variables get copied over on the underlying metrics.
"""
results = []
reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
results.append(metric(*selected_args, **selected_kwargs))
if results[0] is None:
return None
return results

@property
def is_differentiable(self) -> bool:
return False

def reset(self) -> None:
"""Reset all underlying metrics."""
for metric in self.metrics:
metric.reset()