diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7af766d1c87..ca835120400 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
 
+- Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510))
+
+
 ### Changed
 
 - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index ca891c6afd5..a92d477f56c 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -578,3 +578,9 @@ MetricTracker
 
 .. autoclass:: torchmetrics.MetricTracker
     :noindex:
+
+MultioutputWrapper
+~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.MultioutputWrapper
+    :noindex:
diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py
new file mode 100644
index 00000000000..35ab90af092
--- /dev/null
+++ b/tests/wrappers/test_multioutput.py
@@ -0,0 +1,142 @@
+from collections import namedtuple
+from functools import partial
+from typing import Any, Callable, Optional
+
+import pytest
+import torch
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import r2_score as sk_r2score
+
+from tests.helpers import seed_all
+from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester
+from torchmetrics import Metric
+from torchmetrics.classification import Accuracy
+from torchmetrics.regression import R2Score
+from torchmetrics.wrappers.multioutput import MultioutputWrapper
+
+seed_all(42)
+
+
+class _MultioutputMetric(Metric):
+    """Test class that allows passing base metric as a class rather than its instantiation to the wrapper."""
+
+    def __init__(
+        self,
+        base_metric_class,
+        num_outputs: int = 1,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Any = None,
+        dist_sync_fn: Optional[Callable] = None,
+        **base_metric_kwargs,
+    ) -> None:
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+        self.metric = MultioutputWrapper(
+            base_metric_class(
+                compute_on_step=compute_on_step,
+                dist_sync_on_step=dist_sync_on_step,
+                process_group=process_group,
+                dist_sync_fn=dist_sync_fn,
+                **base_metric_kwargs,
+            ),
+            num_outputs=num_outputs,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            dist_sync_fn=dist_sync_fn,
+        )
+
+    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
+        """Update each pair of outputs and predictions."""
+        return self.metric.update(preds, target)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the wrapped metric for each pair of outputs and predictions."""
+        return self.metric.compute()
+
+    @torch.jit.unused
+    def forward(self, *args, **kwargs):
+        """Run forward on the underlying metric."""
+        return self.metric(*args, **kwargs)
+
+    def reset(self) -> None:
+        """Reset the underlying metric state."""
+        self.metric.reset()
+
+
+num_targets = 2
+
+Input = namedtuple("Input", ["preds", "target"])
+
+_multi_target_regression_inputs = Input(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
+    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
+)
+_multi_target_classification_inputs = Input(
+    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets),
+    target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, num_targets)),
+)
+
+
+def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values"):
+    """Compute R2 score over multiple outputs."""
+    sk_preds = preds.view(-1, num_targets).numpy()
+    sk_target = target.view(-1, num_targets).numpy()
+    r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput)
+    if adjusted != 0:
+        r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1)
+    return r2_score
+
+
+def _multi_target_sk_accuracy(preds, target, num_outputs):
+    """Compute accuracy over multiple outputs."""
+    accs = []
+    for i in range(num_outputs):
+        accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i]))
+    return accs
+
+
+@pytest.mark.parametrize(
+    "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs",
+    [
+        (
+            R2Score,
+            _multi_target_sk_r2score,
+            _multi_target_regression_inputs.preds,
+            _multi_target_regression_inputs.target,
+            num_targets,
+            {},
+        ),
+        (
+            Accuracy,
+            partial(_multi_target_sk_accuracy, num_outputs=2),
+            _multi_target_classification_inputs.preds,
+            _multi_target_classification_inputs.target,
+            num_targets,
+            dict(num_classes=NUM_CLASSES),
+        ),
+    ],
+)
+class TestMultioutputWrapper(MetricTester):
+    """Test the MultioutputWrapper class with regression and classification inner metrics."""
+
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
+    def test_multioutput_wrapper(
+        self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step
+    ):
+        """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for
+        both classification and regression metrics."""
+        self.run_class_metric_test(
+            ddp,
+            preds,
+            target,
+            _MultioutputMetric,
+            compare_metric,
+            dist_sync_on_step,
+            metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs),
+        )
diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py
index 5f7798a56e6..4d701c4e0d1 100644
--- a/torchmetrics/__init__.py
+++ b/torchmetrics/__init__.py
@@ -64,7 +64,7 @@
     RetrievalRecall,
 )
 from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore  # noqa: E402
-from torchmetrics.wrappers import BootStrapper, MetricTracker  # noqa: E402
+from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper  # noqa: E402
 
 __all__ = [
     "functional",
@@ -103,6 +103,7 @@
     "Metric",
     "MetricCollection",
    "MetricTracker",
+    "MultioutputWrapper",
     "PearsonCorrcoef",
     "PIT",
     "Precision",
diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py
index 1655a0bac3c..5bca8460c89 100644
--- a/torchmetrics/wrappers/__init__.py
+++ b/torchmetrics/wrappers/__init__.py
@@ -12,4 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.wrappers.bootstrapping import BootStrapper  # noqa: F401
+from torchmetrics.wrappers.multioutput import MultioutputWrapper  # noqa: F401
 from torchmetrics.wrappers.tracker import MetricTracker  # noqa: F401
diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py
new file mode 100644
index 00000000000..de0b01b642c
--- /dev/null
+++ b/torchmetrics/wrappers/multioutput.py
@@ -0,0 +1,168 @@
+from copy import deepcopy
+from typing import Any, Callable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from torchmetrics import Metric
+from torchmetrics.utilities import apply_to_collection
+
+
+def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor:
+    """Get indices of rows along dim 0 which have NaN values."""
+    if len(tensors) == 0:
+        raise ValueError("Must pass at least one tensor as argument")
+    sentinel = tensors[0]
+    nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device)
+    for tensor in tensors:
+        permuted_tensor = tensor.flatten(start_dim=1)
+        nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)
+    return nan_idxs
+
+
+class MultioutputWrapper(Metric):
+    """Wrap a base metric to enable it to support multiple outputs.
+
+    Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef`, lack support for
+    multioutput mode. This class wraps such metrics to support computing one metric per output.
+    Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs.
+    This means if you set `num_outputs` to 2, `compute()` will return a list with two elements, each of which has
+    the dimensions the metric returns when not wrapped.
+
+    In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude
+    fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will remove the
+    NaN containing "rows" upon each update for each output. For example, suppose a user uses
+    `MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally
+    has missing labels denoted by NaN. With ``remove_nans`` enabled, the rows containing NaN values are dropped on a
+    per-output basis before each underlying metric is updated, so that output is only scored on rows with a label.
+
+    Args:
+        base_metric:
+            Metric being wrapped.
+        num_outputs:
+            Expected dimensionality of the output dimension. This parameter is
+            used to determine the number of distinct metrics we need to track.
+        output_dim:
+            Dimension on which output is expected. Note that while this provides some flexibility, the output dimension
+            must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels
+            can have a different number of dimensions than the predictions. This can be worked around if the output
+            dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs.
+        remove_nans:
+            Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying
+            metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N
+            represents the length of the batch or dataset being passed in.
+        squeeze_outputs:
+            If true, will squeeze the 1-item dimensions left after `index_select` is applied.
+            This is sometimes unnecessary but harmless for metrics such as `R2Score`, but useful
+            for certain classification metrics that can't handle additional 1-item dimensions.
+        compute_on_step:
+            Whether to recompute the metric value on each update step.
+        dist_sync_on_step:
+            Required for distributed training support.
+        process_group:
+            Specify the process group on which synchronization is called.
+            The default: None (which selects the entire world)
+        dist_sync_fn:
+            Required for distributed training support.
+
+    Example:
+
+        >>> # Mimic R2Score in `multioutput`, `raw_values` mode:
+        >>> import torch
+        >>> from torchmetrics import MultioutputWrapper, R2Score
+        >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
+        >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
+        >>> r2score = MultioutputWrapper(R2Score(), 2)
+        >>> r2score(preds, target)
+        [tensor(0.9654), tensor(0.9082)]
+        >>> # Classification metric where prediction and label tensors have different shapes.
+        >>> from torchmetrics import BinnedAveragePrecision
+        >>> target = torch.tensor([[1, 2], [2, 0], [1, 2]])
+        >>> preds = torch.tensor([
+        ...     [[.1, .8], [.8, .05], [.1, .15]],
+        ...     [[.1, .1], [.2, .3], [.7, .6]],
+        ...     [[.002, .4], [.95, .45], [.048, .15]]
+        ... ])
+        >>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2)
+        >>> binned_avg_precision(preds, target)
+        [[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]]
+    """
+
+    def __init__(
+        self,
+        base_metric: Metric,
+        num_outputs: int,
+        output_dim: int = -1,
+        remove_nans: bool = True,
+        squeeze_outputs: bool = True,
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+    ):
+        super().__init__(
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+        )
+        self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)])
+        self.output_dim = output_dim
+        self.remove_nans = remove_nans
+        self.squeeze_outputs = squeeze_outputs
+
+    def _get_args_kwargs_by_output(
+        self, *args: torch.Tensor, **kwargs: torch.Tensor
+    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
+        """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out."""
+        args_kwargs_by_output = []
+        for i in range(len(self.metrics)):
+            selected_args = apply_to_collection(
+                args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
+            )
+            selected_kwargs = apply_to_collection(
+                kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device)
+            )
+            if self.remove_nans:
+                args_kwargs = selected_args + tuple(selected_kwargs.values())
+                nan_idxs = _get_nan_indices(*args_kwargs)
+                selected_args = [arg[~nan_idxs] for arg in selected_args]
+                selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()}
+
+            if self.squeeze_outputs:
+                selected_args = [arg.squeeze(self.output_dim) for arg in selected_args]
+            args_kwargs_by_output.append((selected_args, selected_kwargs))
+        return args_kwargs_by_output
+
+    def update(self, *args: Any, **kwargs: Any) -> None:
+        """Update each underlying metric with the corresponding output."""
+        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
+        for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
+            metric.update(*selected_args, **selected_kwargs)
+
+    def compute(self) -> List[torch.Tensor]:
+        """Compute metrics."""
+        return [m.compute() for m in self.metrics]
+
+    @torch.jit.unused
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
+        """Call underlying forward methods and aggregate the results if they're non-null.
+
+        We override this method to ensure that state variables get copied over on the underlying metrics.
+        """
+        results = []
+        reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs)
+        for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs):
+            results.append(metric(*selected_args, **selected_kwargs))
+        if results[0] is None:
+            return None
+        return results
+
+    @property
+    def is_differentiable(self) -> bool:
+        return False
+
+    def reset(self) -> None:
+        """Reset all underlying metrics."""
+        for metric in self.metrics:
+            metric.reset()
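
Note: the docstring above describes the ``remove_nans`` behaviour only in prose. The minimal sketch below (not part of the patch, with purely illustrative data) shows how NaN-containing rows are dropped per output before each underlying metric is updated; it only uses names introduced by the code added in this diff, and the printed values are not asserted here.

    import torch
    from torchmetrics import MultioutputWrapper, R2Score

    # Two regression outputs; the second output is missing one label (NaN).
    preds = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
    target = torch.tensor([[0.5, 1.0], [-1.0, float("nan")], [7.0, -6.0]])

    # With remove_nans=True (the default), NaN-containing rows are dropped per output before updating,
    # so output 0 is scored on all three rows while output 1 is scored only on the two labeled rows.
    r2 = MultioutputWrapper(R2Score(), num_outputs=2, remove_nans=True)
    r2.update(preds, target)
    print(r2.compute())  # a list with one R2 tensor per output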