From d9b477f74e35797b0de1b2c551cd8da4451ad080 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Wed, 8 Sep 2021 18:37:47 -0400 Subject: [PATCH 01/37] Implement MultioutputWrapper --- tests/wrappers/test_multioutput.py | 27 +++++++ torchmetrics/wrappers/multioutput.py | 106 +++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 tests/wrappers/test_multioutput.py create mode 100644 torchmetrics/wrappers/multioutput.py diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py new file mode 100644 index 00000000000..2429153ab61 --- /dev/null +++ b/tests/wrappers/test_multioutput.py @@ -0,0 +1,27 @@ +import torch + +from torchmetrics.wrappers.multioutput import MultioutputWrapper +from torchmetrics.classification import Accuracy +from torchmetrics.regression import R2Score + + +def test_multioutput_wrapper(): + # Multiple outputs, same shapes + preds1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float) + target1 = torch.tensor([[1, 4], [3, 2], [5, 6]], dtype=torch.float) + preds2 = torch.tensor([[7, 8], [9, 10], [11, 12]], dtype=torch.float) + target2 = torch.tensor([[7, 8], [9, 10], [11, 12]], dtype=torch.float) + + r2 = MultioutputWrapper(R2Score(), num_outputs=2) + r2.update(preds1, target1) + r2.update(preds2, target2) + + # R2 score computed using sklearn's r2_score + torch.testing.assert_allclose(r2.compute(), [1, 0.8857]) + + # Multiple outputs, different shapes + acc = MultioutputWrapper(Accuracy(num_classes=3), num_outputs=2) + preds = torch.tensor([[[0.1, 0.3], [0.8, 0.3], [0.1, 0.4]], [[0.8, 0.3], [0.1, 0.4], [0.1, 0.3]]]) + target = torch.tensor([[1, 2], [1, 1]]) + acc.update(preds, target) + torch.testing.assert_allclose(acc.compute(), [0.5, 1.0]) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py new file mode 100644 index 00000000000..33aa83778f7 --- /dev/null +++ b/torchmetrics/wrappers/multioutput.py @@ -0,0 +1,106 @@ +from copy import deepcopy +from typing import Any, Callable, List, Optional + +import torch +from torch import nn +from torchmetrics import Metric +from torchmetrics.utilities import apply_to_collection + + +def _get_nan_indices(*tensors: torch.Tensor, aligned_dim: int = 0) -> torch.Tensor: + if len(tensors) == 0: + raise ValueError("Must pass at least one tensor as argument") + sentinel = tensors[0] + nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool) + for tensor in tensors: + assert ( + tensor.shape[aligned_dim] == sentinel.shape[aligned_dim] + ), f"{tensor.shape} != {sentinel.shape} (dim: {aligned_dim})" + permuted_tensor = tensor.movedim(aligned_dim, 0).flatten(start_dim=1) + nan_idxs |= torch.any(permuted_tensor.isnan(), dim=1) + return nan_idxs + + +class MultioutputWrapper(Metric): + """ + Wrap a base metric to enable it to support multiple outputs. + + Several torchmetrics metrics, such as `SpearmanCorroef` lack support for multioutput mode. This class wraps such + metrics to support computing one metric per output. Unlike specific torchmetric metrics, it doesn't support any + aggregation across outputs. This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension + (2, ...) where ... represents the dimensions the metric returns when not wrapped. + + In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude + fashion, dealing with missing labels (or other data). When `remove_nans` is passed, the class will remove the + intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses + `MultioutputWrapper` to wrap `R2Score` with 2 outputs, one of which occasionally has missing labels for classes like + `R2Score` is that this class supports removing NaN values (parameter `remove_nans`) on a per-output basis. When + `remove_nans` is passed the wrapper will remove all rows + + Parameters + ---------- + base_metric: Metric, + Metric being wrapped. + num_outputs: int = 1, + Expected dimensionality of the output dimension. This parameter is + used to determine the number of distinct metrics we need to track. + output_dim: int = -1, + Dimension on which output is expected. Note that while this provides some flexibility, the output dimension + must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels + can have a different number of dimensions than the predictions. This can be worked around if the output dimension + can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. + remove_nans: bool = True, + Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying + metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents + the length of the batch or dataset being passed in. + compute_on_step: bool = True, + Whether to recompute the metric value on each update step. + dist_sync_on_step: bool = False, + Required for distributed training support. See torchmetrics docs for additional details. + dist_sync_fn: Callable = None, + Required for distributed training support. See torchmetrics docs for additional details. + """ + + def __init__( + self, + base_metric: Metric, + num_outputs: int = 1, + output_dim: int = -1, + remove_nans: bool = True, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + dist_sync_fn: Callable = None, + ): + super().__init__(compute_on_step, dist_sync_on_step, dist_sync_fn) + self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)]) + self.output_dim = output_dim + self.remove_nans = remove_nans + + def update(self, *args: Any, **kwargs: Any) -> None: + """Update each underlying metric with the corresponding output.""" + for i in range(len(self.metrics)): + selected_args = apply_to_collection( + args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i) + ) + selected_kwargs = apply_to_collection( + kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i) + ) + if self.remove_nans: + args_kwargs = selected_args + tuple(selected_kwargs.values()) + nan_idxs = _get_nan_indices(*args_kwargs) + selected_args = [arg[~nan_idxs] for arg in selected_args] + selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()} + self.metrics[i].update(*selected_args, **selected_kwargs) + + def compute(self) -> torch.Tensor: + """Compute metrics.""" + return torch.stack([m.compute() for m in self.metrics], dim=0) + + @property + def is_differentiable(self) -> bool: + False + + def reset(self): + """Reset all underlying metrics.""" + for metric in self.metrics: + metric.reset() From d6f3f15f7585a18a12ee3c4bb1c631586829385b Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Wed, 8 Sep 2021 18:46:46 -0400 Subject: [PATCH 02/37] Format docstrings properly --- torchmetrics/wrappers/multioutput.py | 43 ++++++++++++++-------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 33aa83778f7..fa267384e85 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -37,28 +37,27 @@ class MultioutputWrapper(Metric): `R2Score` is that this class supports removing NaN values (parameter `remove_nans`) on a per-output basis. When `remove_nans` is passed the wrapper will remove all rows - Parameters - ---------- - base_metric: Metric, - Metric being wrapped. - num_outputs: int = 1, - Expected dimensionality of the output dimension. This parameter is - used to determine the number of distinct metrics we need to track. - output_dim: int = -1, - Dimension on which output is expected. Note that while this provides some flexibility, the output dimension - must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels - can have a different number of dimensions than the predictions. This can be worked around if the output dimension - can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. - remove_nans: bool = True, - Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying - metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents - the length of the batch or dataset being passed in. - compute_on_step: bool = True, - Whether to recompute the metric value on each update step. - dist_sync_on_step: bool = False, - Required for distributed training support. See torchmetrics docs for additional details. - dist_sync_fn: Callable = None, - Required for distributed training support. See torchmetrics docs for additional details. + Args: + base_metric: Metric, + Metric being wrapped. + num_outputs: int = 1, + Expected dimensionality of the output dimension. This parameter is + used to determine the number of distinct metrics we need to track. + output_dim: int = -1, + Dimension on which output is expected. Note that while this provides some flexibility, the output dimension + must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels + can have a different number of dimensions than the predictions. This can be worked around if the output dimension + can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. + remove_nans: bool = True, + Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying + metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents + the length of the batch or dataset being passed in. + compute_on_step: bool = True, + Whether to recompute the metric value on each update step. + dist_sync_on_step: bool = False, + Required for distributed training support. See torchmetrics docs for additional details. + dist_sync_fn: Callable = None, + Required for distributed training support. See torchmetrics docs for additional details. """ def __init__( From bf57891a9c35b115e62736e2a490183cc6a057be Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Wed, 8 Sep 2021 21:01:39 -0400 Subject: [PATCH 03/37] Update docs & exports --- docs/source/references/modules.rst | 6 ++++++ torchmetrics/__init__.py | 2 +- torchmetrics/wrappers/__init__.py | 1 + torchmetrics/wrappers/multioutput.py | 31 ++++++++++++++++++---------- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index ca891c6afd5..a92d477f56c 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -578,3 +578,9 @@ MetricTracker .. autoclass:: torchmetrics.MetricTracker :noindex: + +MultioutputWrapper +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.MultioutputWrapper + :noindex: diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 5f7798a56e6..9432910c7d7 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -64,7 +64,7 @@ RetrievalRecall, ) from torchmetrics.text import WER, BERTScore, BLEUScore, ROUGEScore # noqa: E402 -from torchmetrics.wrappers import BootStrapper, MetricTracker # noqa: E402 +from torchmetrics.wrappers import BootStrapper, MetricTracker, MultioutputWrapper # noqa: E402 __all__ = [ "functional", diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 1655a0bac3c..e008f53124f 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 +from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index fa267384e85..e756dbddb61 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Optional import torch from torch import nn @@ -25,17 +25,18 @@ class MultioutputWrapper(Metric): """ Wrap a base metric to enable it to support multiple outputs. - Several torchmetrics metrics, such as `SpearmanCorroef` lack support for multioutput mode. This class wraps such - metrics to support computing one metric per output. Unlike specific torchmetric metrics, it doesn't support any + Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for + multioutput mode. This class wraps such metrics to support computing one metric per output. + Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs. This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension (2, ...) where ... represents the dimensions the metric returns when not wrapped. In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude fashion, dealing with missing labels (or other data). When `remove_nans` is passed, the class will remove the intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses - `MultioutputWrapper` to wrap `R2Score` with 2 outputs, one of which occasionally has missing labels for classes like - `R2Score` is that this class supports removing NaN values (parameter `remove_nans`) on a per-output basis. When - `remove_nans` is passed the wrapper will remove all rows + `MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally + has missing labels for classes like `R2Score` is that this class supports removing NaN values + (parameter `remove_nans`) on a per-output basis. When `remove_nans` is passed the wrapper will remove all rows Args: base_metric: Metric, @@ -46,16 +47,18 @@ class MultioutputWrapper(Metric): output_dim: int = -1, Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels - can have a different number of dimensions than the predictions. This can be worked around if the output dimension - can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. + can have a different number of dimensions than the predictions. This can be worked around if the output + dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. remove_nans: bool = True, Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying - metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents - the length of the batch or dataset being passed in. + metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N + represents the length of the batch or dataset being passed in. compute_on_step: bool = True, Whether to recompute the metric value on each update step. dist_sync_on_step: bool = False, Required for distributed training support. See torchmetrics docs for additional details. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: Callable = None, Required for distributed training support. See torchmetrics docs for additional details. """ @@ -68,9 +71,15 @@ def __init__( remove_nans: bool = True, compute_on_step: bool = True, dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__(compute_on_step, dist_sync_on_step, dist_sync_fn) + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)]) self.output_dim = output_dim self.remove_nans = remove_nans From d98f91833eb9dd8e9becf78db5c62b010f56d706 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Sep 2021 01:03:40 +0000 Subject: [PATCH 04/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 2 +- torchmetrics/wrappers/__init__.py | 2 +- torchmetrics/wrappers/multioutput.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 2429153ab61..82f510688ee 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -1,8 +1,8 @@ import torch -from torchmetrics.wrappers.multioutput import MultioutputWrapper from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score +from torchmetrics.wrappers.multioutput import MultioutputWrapper def test_multioutput_wrapper(): diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index e008f53124f..5bca8460c89 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 -from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 +from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index e756dbddb61..743989cc160 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -3,6 +3,7 @@ import torch from torch import nn + from torchmetrics import Metric from torchmetrics.utilities import apply_to_collection @@ -22,8 +23,7 @@ def _get_nan_indices(*tensors: torch.Tensor, aligned_dim: int = 0) -> torch.Tens class MultioutputWrapper(Metric): - """ - Wrap a base metric to enable it to support multiple outputs. + """Wrap a base metric to enable it to support multiple outputs. Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. From 373859ef7215ff97d4c6220385ceae0d237169dc Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Wed, 8 Sep 2021 21:20:04 -0400 Subject: [PATCH 05/37] Add wrapper to __all__ --- torchmetrics/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 9432910c7d7..4d701c4e0d1 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -103,6 +103,7 @@ "Metric", "MetricCollection", "MetricTracker", + "MultioutputWrapper", "PearsonCorrcoef", "PIT", "Precision", From 65977841782497298a87d06405ceb7402c2cbf1b Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Wed, 8 Sep 2021 21:26:47 -0400 Subject: [PATCH 06/37] Address deepsource flagged issues --- tests/wrappers/test_multioutput.py | 2 ++ torchmetrics/wrappers/multioutput.py | 10 ++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 82f510688ee..2ec3227788c 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -6,6 +6,8 @@ def test_multioutput_wrapper(): + """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both + classification and regression metrics.""" # Multiple outputs, same shapes preds1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float) target1 = torch.tensor([[1, 4], [3, 2], [5, 6]], dtype=torch.float) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 743989cc160..ba71a429d7a 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -9,14 +9,12 @@ def _get_nan_indices(*tensors: torch.Tensor, aligned_dim: int = 0) -> torch.Tensor: + """Get indices of rows along `aligned_dim` which have NaN values.""" if len(tensors) == 0: raise ValueError("Must pass at least one tensor as argument") sentinel = tensors[0] nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool) for tensor in tensors: - assert ( - tensor.shape[aligned_dim] == sentinel.shape[aligned_dim] - ), f"{tensor.shape} != {sentinel.shape} (dim: {aligned_dim})" permuted_tensor = tensor.movedim(aligned_dim, 0).flatten(start_dim=1) nan_idxs |= torch.any(permuted_tensor.isnan(), dim=1) return nan_idxs @@ -86,7 +84,7 @@ def __init__( def update(self, *args: Any, **kwargs: Any) -> None: """Update each underlying metric with the corresponding output.""" - for i in range(len(self.metrics)): + for i, metric in enumerate(self.metrics): selected_args = apply_to_collection( args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i) ) @@ -98,7 +96,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: nan_idxs = _get_nan_indices(*args_kwargs) selected_args = [arg[~nan_idxs] for arg in selected_args] selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()} - self.metrics[i].update(*selected_args, **selected_kwargs) + metric.update(*selected_args, **selected_kwargs) def compute(self) -> torch.Tensor: """Compute metrics.""" @@ -106,7 +104,7 @@ def compute(self) -> torch.Tensor: @property def is_differentiable(self) -> bool: - False + return False def reset(self): """Reset all underlying metrics.""" From 7f713f0154730bfff0e2a45ace000919da34821c Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Thu, 9 Sep 2021 09:27:24 -0400 Subject: [PATCH 07/37] Address PR comments - In order to handle cases like `BinnedAveragePrecision`, add `squeeze_outputs` argument which defaults to true and change `compute()`'s return type to a list. - Add two examples to the docs. --- torchmetrics/wrappers/multioutput.py | 47 +++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index ba71a429d7a..b153d315085 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import torch from torch import nn @@ -51,7 +51,10 @@ class MultioutputWrapper(Metric): Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents the length of the batch or dataset being passed in. - compute_on_step: bool = True, + squeeze_outputs: bool = True + Whether to + + compute_on_step: bool = True Whether to recompute the metric value on each update step. dist_sync_on_step: bool = False, Required for distributed training support. See torchmetrics docs for additional details. @@ -59,14 +62,40 @@ class MultioutputWrapper(Metric): Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: Callable = None, Required for distributed training support. See torchmetrics docs for additional details. + + Example: + + Mimic R2Score in `multioutput`, `raw_values` mode: + >>> import torch + >>> from torchmetrics import MultioutputWrapper, R2Score + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score = MultioutputWrapper(R2Score(), 2) + >>> r2score(preds, target) + [tensor(0.9654), tensor(0.9082)] + + Classification metric where prediction and label tensors have different shapes. + >>> from torchmetrics import BinnedAveragePrecision + >>> target = torch.tensor([[1, 2], [2, 0], [1, 2]]) + >>> preds = torch.tensor([ + ... [[.1, .8], [.8, .05], [.1, .15]], + ... [[.1, .1], [.2, .3], [.7, .6]], + ... [[.002, .4], [.95, .45], [.048, .15]] + ... ]) + >>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2) + >>> binned_avg_precision(preds, target) + [[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]] + + """ def __init__( self, base_metric: Metric, - num_outputs: int = 1, + num_outputs: int, output_dim: int = -1, remove_nans: bool = True, + squeeze_outputs: bool = True, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -86,21 +115,25 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update each underlying metric with the corresponding output.""" for i, metric in enumerate(self.metrics): selected_args = apply_to_collection( - args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i) + args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device) ) selected_kwargs = apply_to_collection( - kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i) + kwargs, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device) ) if self.remove_nans: args_kwargs = selected_args + tuple(selected_kwargs.values()) nan_idxs = _get_nan_indices(*args_kwargs) selected_args = [arg[~nan_idxs] for arg in selected_args] selected_kwargs = {k: v[~nan_idxs] for k, v in selected_kwargs.items()} + + if self.squeeze_outputs: + selected_args = [arg.squeeze(self.output_dim) for arg in selected_args] + selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()} metric.update(*selected_args, **selected_kwargs) - def compute(self) -> torch.Tensor: + def compute(self) -> List[torch.Tensor]: """Compute metrics.""" - return torch.stack([m.compute() for m in self.metrics], dim=0) + return [m.compute() for m in self.metrics] @property def is_differentiable(self) -> bool: From 6219c101bb2ceef1141a31c5dc66c4a53513a35d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Sep 2021 13:31:37 +0000 Subject: [PATCH 08/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/multioutput.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index b153d315085..72fe1f63853 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -85,8 +85,6 @@ class MultioutputWrapper(Metric): >>> binned_avg_precision = MultioutputWrapper(BinnedAveragePrecision(3, thresholds=5), 2) >>> binned_avg_precision(preds, target) [[tensor(-0.), tensor(1.0000), tensor(1.0000)], [tensor(0.3333), tensor(-0.), tensor(0.6667)]] - - """ def __init__( From fa02473c8bbc13a9498ed259ca0ac72d415f64fb Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Thu, 9 Sep 2021 09:33:14 -0400 Subject: [PATCH 09/37] Update docs & make squeeze_outputs actually work --- torchmetrics/wrappers/multioutput.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 72fe1f63853..6a51eb598fa 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -39,26 +39,27 @@ class MultioutputWrapper(Metric): Args: base_metric: Metric, Metric being wrapped. - num_outputs: int = 1, + num_outputs: int Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics we need to track. - output_dim: int = -1, + output_dim: int = -1 Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. - remove_nans: bool = True, + remove_nans: bool = True Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents the length of the batch or dataset being passed in. squeeze_outputs: bool = True - Whether to - + If true, will squeeze the 1-item dimensions left after `index_select` is applied. + This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful + for certain classification metrics that can't handle additional 1-item dimensions. compute_on_step: bool = True Whether to recompute the metric value on each update step. - dist_sync_on_step: bool = False, + dist_sync_on_step: bool = False Required for distributed training support. See torchmetrics docs for additional details. - process_group: + process_group: Optional[Any] Specify the process group on which synchronization is called. default: None (which selects the entire world) dist_sync_fn: Callable = None, Required for distributed training support. See torchmetrics docs for additional details. @@ -108,6 +109,7 @@ def __init__( self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_outputs)]) self.output_dim = output_dim self.remove_nans = remove_nans + self.squeeze_outputs = squeeze_outputs def update(self, *args: Any, **kwargs: Any) -> None: """Update each underlying metric with the corresponding output.""" From 5cc7924f4239c1cc3522ac1ef03123e1d175408a Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Thu, 9 Sep 2021 10:18:35 -0400 Subject: [PATCH 10/37] Update tests to be randomized and parametrized --- tests/wrappers/test_multioutput.py | 58 ++++++++++++++++++++---------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 2ec3227788c..d1f34980ac1 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -1,3 +1,7 @@ +from functools import partial + +import pytest +from sklearn.metrics import accuracy_score import torch from torchmetrics.classification import Accuracy @@ -5,25 +9,41 @@ from torchmetrics.wrappers.multioutput import MultioutputWrapper -def test_multioutput_wrapper(): - """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both - classification and regression metrics.""" - # Multiple outputs, same shapes - preds1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float) - target1 = torch.tensor([[1, 4], [3, 2], [5, 6]], dtype=torch.float) - preds2 = torch.tensor([[7, 8], [9, 10], [11, 12]], dtype=torch.float) - target2 = torch.tensor([[7, 8], [9, 10], [11, 12]], dtype=torch.float) +def _multioutput_sk_accuracy(preds, target, num_outputs): + accs = [] + for i in range(num_outputs): + accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i])) + return accs - r2 = MultioutputWrapper(R2Score(), num_outputs=2) - r2.update(preds1, target1) - r2.update(preds2, target2) - # R2 score computed using sklearn's r2_score - torch.testing.assert_allclose(r2.compute(), [1, 0.8857]) - # Multiple outputs, different shapes - acc = MultioutputWrapper(Accuracy(num_classes=3), num_outputs=2) - preds = torch.tensor([[[0.1, 0.3], [0.8, 0.3], [0.1, 0.4]], [[0.8, 0.3], [0.1, 0.4], [0.1, 0.3]]]) - target = torch.tensor([[1, 2], [1, 1]]) - acc.update(preds, target) - torch.testing.assert_allclose(acc.compute(), [0.5, 1.0]) +@pytest.mark.parametrize( + "metric, compare_metric, pred_generator, target_generator, num_rounds", + [ + ( + MultioutputWrapper(R2Score(), num_outputs=2), + R2Score(num_outputs=2, multioutput="raw_values"), + partial(torch.randn, 10, 2), + partial(torch.randn, 10, 2), + 2, + ), + ( + MultioutputWrapper(Accuracy(num_classes=3), num_outputs=2), + partial(_multioutput_sk_accuracy, num_outputs=2), + partial(torch.rand, 10, 3, 2), + partial(torch.randint, 3, (10, 2)), + 2, + ) + ], +) +def test_multioutput_wrapper(metric, compare_metric, pred_generator, target_generator, num_rounds): + """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both + classification and regression metrics.""" + preds, targets = [], [] + for _ in range(num_rounds): + preds.append(pred_generator()) + targets.append(target_generator()) + print(preds[-1].shape, targets[-1].shape) + metric.update(preds[-1], targets[-1]) + expected_metric_val = compare_metric(torch.cat(preds), torch.cat(targets)) + torch.testing.assert_allclose(metric.compute(), expected_metric_val) From 424275f7d12869763a8b284e37795d64299b4bf1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Sep 2021 14:19:19 +0000 Subject: [PATCH 11/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index d1f34980ac1..b4a5635fa1b 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -1,8 +1,8 @@ from functools import partial import pytest -from sklearn.metrics import accuracy_score import torch +from sklearn.metrics import accuracy_score from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score @@ -16,7 +16,6 @@ def _multioutput_sk_accuracy(preds, target, num_outputs): return accs - @pytest.mark.parametrize( "metric, compare_metric, pred_generator, target_generator, num_rounds", [ @@ -33,7 +32,7 @@ def _multioutput_sk_accuracy(preds, target, num_outputs): partial(torch.rand, 10, 3, 2), partial(torch.randint, 3, (10, 2)), 2, - ) + ), ], ) def test_multioutput_wrapper(metric, compare_metric, pred_generator, target_generator, num_rounds): From 932414a15b2d6c6420675ed412332050926363ab Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Thu, 9 Sep 2021 10:20:54 -0400 Subject: [PATCH 12/37] Fix typechecking error --- torchmetrics/wrappers/multioutput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 6a51eb598fa..44ecd687d5c 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -139,7 +139,7 @@ def compute(self) -> List[torch.Tensor]: def is_differentiable(self) -> bool: return False - def reset(self): + def reset(self) -> None: """Reset all underlying metrics.""" for metric in self.metrics: metric.reset() From cf9cb6799f66b665877341af7d7de9bca74666d9 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Fri, 10 Sep 2021 10:36:35 +0200 Subject: [PATCH 13/37] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a23a6128fc2..a5f16c2600c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `average` argument to `AveragePrecision` metric for reducing multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) +- Added `MultioutputWrapper` ([#510](https://github.com/PyTorchLightning/metrics/pull/510)) + + ### Changed - `AveragePrecision` will now as default output the `macro` average for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477)) From 48c0db911fb805598dbed155b6a41765731614e9 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Fri, 10 Sep 2021 14:30:39 -0400 Subject: [PATCH 14/37] Fix forward to not reset the states --- tests/wrappers/test_multioutput.py | 126 ++++++++++++++++++++++----- torchmetrics/wrappers/multioutput.py | 26 ++++-- 2 files changed, 126 insertions(+), 26 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index b4a5635fa1b..94b9137ae74 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -1,48 +1,132 @@ +from collections import namedtuple from functools import partial +from typing import Any, Callable, Optional import pytest import torch from sklearn.metrics import accuracy_score +from sklearn.metrics import r2_score as sk_r2score +from tests.helpers import seed_all +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from torchmetrics import Metric from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 from torchmetrics.wrappers.multioutput import MultioutputWrapper +seed_all(42) + def _multioutput_sk_accuracy(preds, target, num_outputs): accs = [] + print("sk:", preds[:4], target[:4]) for i in range(num_outputs): accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i])) return accs +class _MultioutputMetric(Metric): + """Multi-output version of the R2 score class for testing.""" + + def __init__( + self, + base_metric_class, + num_outputs: int = 1, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Any = None, + dist_sync_fn: Optional[Callable] = None, + **base_metric_kwargs, + ) -> None: + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + self.metric = MultioutputWrapper( + base_metric_class(**base_metric_kwargs), + num_outputs=num_outputs, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + dist_sync_fn=dist_sync_fn, + ) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update the each pair of outputs and predictions.""" + return self.metric.update(preds, target) + + def compute(self) -> torch.Tensor: + """Compute the R2 score between each pair of outputs and predictions.""" + return self.metric.compute() + + @torch.jit.unused + def forward(self, *args, **kwargs): + return self.metric(*args, **kwargs) + + def reset(self) -> None: + """Reset the underlying metric state.""" + self.metric.reset() + + + +num_classes = 3 +num_targets = 2 + +Input = namedtuple("Input", ["preds", "target"]) + +_multi_target_regression_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) +_multi_target_classification_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), + target=torch.randint(num_classes, (NUM_BATCHES, BATCH_SIZE, num_targets)), +) + + +def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values"): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score + @pytest.mark.parametrize( - "metric, compare_metric, pred_generator, target_generator, num_rounds", + "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs", [ ( - MultioutputWrapper(R2Score(), num_outputs=2), - R2Score(num_outputs=2, multioutput="raw_values"), - partial(torch.randn, 10, 2), - partial(torch.randn, 10, 2), - 2, + R2Score, + _multi_target_sk_r2score, + _multi_target_regression_inputs.preds, + _multi_target_regression_inputs.target, + num_targets, + {}, ), ( - MultioutputWrapper(Accuracy(num_classes=3), num_outputs=2), + Accuracy, partial(_multioutput_sk_accuracy, num_outputs=2), - partial(torch.rand, 10, 3, 2), - partial(torch.randint, 3, (10, 2)), - 2, + _multi_target_classification_inputs.preds, + _multi_target_classification_inputs.target, + num_targets, + dict(num_classes=num_classes), ), ], ) -def test_multioutput_wrapper(metric, compare_metric, pred_generator, target_generator, num_rounds): - """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both - classification and regression metrics.""" - preds, targets = [], [] - for _ in range(num_rounds): - preds.append(pred_generator()) - targets.append(target_generator()) - print(preds[-1].shape, targets[-1].shape) - metric.update(preds[-1], targets[-1]) - expected_metric_val = compare_metric(torch.cat(preds), torch.cat(targets)) - torch.testing.assert_allclose(metric.compute(), expected_metric_val) +class TestMultioutputWrapper(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_multioutput_wrapper(self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step): + """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both + classification and regression metrics.""" + self.run_class_metric_test( + ddp, + preds, + target, + _MultioutputMetric, + compare_metric, + dist_sync_on_step, + metric_args=dict(num_outputs=num_outputs, base_metric_class=base_metric_class, **metric_kwargs), + ) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 44ecd687d5c..6a356e12f16 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -111,9 +111,9 @@ def __init__( self.remove_nans = remove_nans self.squeeze_outputs = squeeze_outputs - def update(self, *args: Any, **kwargs: Any) -> None: - """Update each underlying metric with the corresponding output.""" - for i, metric in enumerate(self.metrics): + def _get_args_kwargs_by_output(self, *args, **kwargs): + args_kwargs_by_output = [] + for i in range(len(self.metrics)): selected_args = apply_to_collection( args, torch.Tensor, torch.index_select, dim=self.output_dim, index=torch.tensor(i, device=self.device) ) @@ -128,13 +128,29 @@ def update(self, *args: Any, **kwargs: Any) -> None: if self.squeeze_outputs: selected_args = [arg.squeeze(self.output_dim) for arg in selected_args] - selected_kwargs = {k: v.squeeze(self.output_dim) for k, v in selected_kwargs.items()} - metric.update(*selected_args, **selected_kwargs) + args_kwargs_by_output.append((selected_args, selected_kwargs)) + return args_kwargs_by_output + + def update(self, *args: Any, **kwargs: Any) -> None: + """Update each underlying metric with the corresponding output.""" + reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) + for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): + metric.update(*args, **kwargs) def compute(self) -> List[torch.Tensor]: """Compute metrics.""" return [m.compute() for m in self.metrics] + @torch.jit.unused + def forward(self, *args: Any, **kwargs: Any) -> Any: + results = [] + reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) + for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): + results.append(metric(*selected_args, **selected_kwargs)) + if results[0] is not None: + return results + + @property def is_differentiable(self) -> bool: return False From 190eafd78731e417f6252ce69de0b7612b8715c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Sep 2021 18:31:26 +0000 Subject: [PATCH 15/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 10 ++++++---- torchmetrics/wrappers/multioutput.py | 1 - 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 94b9137ae74..0ca7986b551 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -70,7 +70,6 @@ def reset(self) -> None: self.metric.reset() - num_classes = 3 num_targets = 2 @@ -94,6 +93,7 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) return r2_score + @pytest.mark.parametrize( "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs", [ @@ -118,9 +118,11 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" class TestMultioutputWrapper(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_multioutput_wrapper(self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step): - """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for both - classification and regression metrics.""" + def test_multioutput_wrapper( + self, base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs, ddp, dist_sync_on_step + ): + """Test that the multioutput wrapper properly slices and computes outputs along the output dimension for + both classification and regression metrics.""" self.run_class_metric_test( ddp, preds, diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 6a356e12f16..0aed2baff07 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -150,7 +150,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: if results[0] is not None: return results - @property def is_differentiable(self) -> bool: return False From 774458c6e1acc74540b1536abecc1172ee606985 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Fri, 10 Sep 2021 15:34:04 -0400 Subject: [PATCH 16/37] Fix DeepSource flagged issues --- tests/wrappers/test_multioutput.py | 26 +++++++++++++++----------- torchmetrics/wrappers/multioutput.py | 11 ++++++++++- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 0ca7986b551..c40550dbe2c 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -18,14 +18,6 @@ seed_all(42) -def _multioutput_sk_accuracy(preds, target, num_outputs): - accs = [] - print("sk:", preds[:4], target[:4]) - for i in range(num_outputs): - accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i])) - return accs - - class _MultioutputMetric(Metric): """Multi-output version of the R2 score class for testing.""" @@ -76,8 +68,7 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), @@ -86,6 +77,7 @@ def reset(self) -> None: def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values"): + """Compute R2 score over multiple outputs.""" sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) @@ -94,6 +86,17 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" return r2_score +<<<<<<< HEAD +======= +def _multi_target_sk_accuracy(preds, target, num_outputs): + """Compute accuracy over multiple outputs.""" + accs = [] + for i in range(num_outputs): + accs.append(accuracy_score(torch.argmax(preds[:, :, i], dim=1), target[:, i])) + return accs + + +>>>>>>> Fix DeepSource flagged issues @pytest.mark.parametrize( "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs", [ @@ -107,7 +110,7 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" ), ( Accuracy, - partial(_multioutput_sk_accuracy, num_outputs=2), + partial(_multi_target_sk_accuracy, num_outputs=2), _multi_target_classification_inputs.preds, _multi_target_classification_inputs.target, num_targets, @@ -116,6 +119,7 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" ], ) class TestMultioutputWrapper(MetricTester): + """Test the MultioutputWrapper class with regression and classification inner metrics.""" @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_multioutput_wrapper( diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 0aed2baff07..0f2a4bae26f 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -112,6 +112,7 @@ def __init__( self.squeeze_outputs = squeeze_outputs def _get_args_kwargs_by_output(self, *args, **kwargs): + """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" args_kwargs_by_output = [] for i in range(len(self.metrics)): selected_args = apply_to_collection( @@ -135,7 +136,7 @@ def update(self, *args: Any, **kwargs: Any) -> None: """Update each underlying metric with the corresponding output.""" reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): - metric.update(*args, **kwargs) + metric.update(*selected_args, **selected_kwargs) def compute(self) -> List[torch.Tensor]: """Compute metrics.""" @@ -143,12 +144,20 @@ def compute(self) -> List[torch.Tensor]: @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: + """ + Call underlying forward methods and aggregate the results if they're non-null. + + We override this method to ensure that state variables get copied over on the + underlying metrics. + """ results = [] reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): results.append(metric(*selected_args, **selected_kwargs)) if results[0] is not None: return results + else: + return None @property def is_differentiable(self) -> bool: From a5d90c11b545ae04d077f084812f47be2ba0d28a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Sep 2021 20:03:51 +0000 Subject: [PATCH 17/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/multioutput.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 0f2a4bae26f..41977d1bee7 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -144,11 +144,9 @@ def compute(self) -> List[torch.Tensor]: @torch.jit.unused def forward(self, *args: Any, **kwargs: Any) -> Any: - """ - Call underlying forward methods and aggregate the results if they're non-null. + """Call underlying forward methods and aggregate the results if they're non-null. - We override this method to ensure that state variables get copied over on the - underlying metrics. + We override this method to ensure that state variables get copied over on the underlying metrics. """ results = [] reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) From 73d97f9e2306d899138823b306beb278abd88342 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Fri, 10 Sep 2021 16:07:58 -0400 Subject: [PATCH 18/37] Fix leftover conflict code --- tests/wrappers/test_multioutput.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index c40550dbe2c..e944e616e91 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -86,8 +86,6 @@ def _multi_target_sk_r2score(preds, target, adjusted=0, multioutput="raw_values" return r2_score -<<<<<<< HEAD -======= def _multi_target_sk_accuracy(preds, target, num_outputs): """Compute accuracy over multiple outputs.""" accs = [] @@ -96,7 +94,6 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): return accs ->>>>>>> Fix DeepSource flagged issues @pytest.mark.parametrize( "base_metric_class, compare_metric, preds, target, num_outputs, metric_kwargs", [ From f2c3022a640c9c5166744f0d5c58b1ff3edb95b9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 10 Sep 2021 20:09:53 +0000 Subject: [PATCH 19/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index e944e616e91..7ae94586784 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -68,7 +68,8 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), @@ -117,6 +118,7 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): ) class TestMultioutputWrapper(MetricTester): """Test the MultioutputWrapper class with regression and classification inner metrics.""" + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_multioutput_wrapper( From bf3cd727bbb473a83c6676780095fc775450ddc3 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Fri, 10 Sep 2021 16:12:47 -0400 Subject: [PATCH 20/37] Remove unused imports --- tests/wrappers/test_multioutput.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 7ae94586784..241ef607e60 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -12,7 +12,6 @@ from torchmetrics import Metric from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 from torchmetrics.wrappers.multioutput import MultioutputWrapper seed_all(42) From ef59963ce0ff0b1e710e8949dfcfbaff9b5a2dd3 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Fri, 10 Sep 2021 16:22:30 -0400 Subject: [PATCH 21/37] Fix more deepsource issues --- tests/wrappers/test_multioutput.py | 1 + torchmetrics/wrappers/multioutput.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 241ef607e60..134ce5cff36 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -54,6 +54,7 @@ def compute(self) -> torch.Tensor: @torch.jit.unused def forward(self, *args, **kwargs): + """Run forward on the underlying metric.""" return self.metric(*args, **kwargs) def reset(self) -> None: diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 41977d1bee7..286d7531641 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -152,10 +152,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: reshaped_args_kwargs = self._get_args_kwargs_by_output(*args, **kwargs) for metric, (selected_args, selected_kwargs) in zip(self.metrics, reshaped_args_kwargs): results.append(metric(*selected_args, **selected_kwargs)) - if results[0] is not None: - return results - else: + if results[0] is None: return None + return results @property def is_differentiable(self) -> bool: From a2beefa38e12193f0fe1084ca7a8d14f67fe002a Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Mon, 13 Sep 2021 09:05:13 -0400 Subject: [PATCH 22/37] Pass distributed args to base_metric_class --- tests/wrappers/test_multioutput.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 134ce5cff36..6686b00c8ec 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -37,7 +37,12 @@ def __init__( dist_sync_fn=dist_sync_fn, ) self.metric = MultioutputWrapper( - base_metric_class(**base_metric_kwargs), + base_metric_class( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + **base_metric_kwargs), num_outputs=num_outputs, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, From 4ce3a6d32b7a2b1e5cf7619058fc04238857bef7 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Mon, 13 Sep 2021 09:06:12 -0400 Subject: [PATCH 23/37] Fix formatting --- tests/wrappers/test_multioutput.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 6686b00c8ec..3aaa2eada55 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -42,7 +42,8 @@ def __init__( dist_sync_on_step=dist_sync_on_step, process_group=process_group, dist_sync_fn=dist_sync_fn, - **base_metric_kwargs), + **base_metric_kwargs, + ), num_outputs=num_outputs, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, @@ -73,8 +74,7 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), From 5094e4b23f7e3f08441000afc954093dcc46b91a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Sep 2021 13:08:01 +0000 Subject: [PATCH 24/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 3aaa2eada55..25bd8214767 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -74,7 +74,8 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), From 5fba4239be191e526858f8968bf8771ae75c46cf Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Mon, 13 Sep 2021 09:14:41 -0400 Subject: [PATCH 25/37] Fix typecheck error --- torchmetrics/wrappers/multioutput.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 286d7531641..cc67895b261 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple import torch from torch import nn @@ -111,7 +111,7 @@ def __init__( self.remove_nans = remove_nans self.squeeze_outputs = squeeze_outputs - def _get_args_kwargs_by_output(self, *args, **kwargs): + def _get_args_kwargs_by_output(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" args_kwargs_by_output = [] for i in range(len(self.metrics)): From 44add1ff57e4603493e7a2efb1685fc0a2ad6434 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Sep 2021 13:15:22 +0000 Subject: [PATCH 26/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/wrappers/multioutput.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index cc67895b261..ba98ddc850b 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -111,7 +111,9 @@ def __init__( self.remove_nans = remove_nans self.squeeze_outputs = squeeze_outputs - def _get_args_kwargs_by_output(self, *args: torch.Tensor, **kwargs: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]: + def _get_args_kwargs_by_output( + self, *args: torch.Tensor, **kwargs: torch.Tensor + ) -> List[Tuple[torch.Tensor, torch.Tensor]]: """Get args and kwargs reshaped to be output-specific and (maybe) having NaNs stripped out.""" args_kwargs_by_output = [] for i in range(len(self.metrics)): From 0e089adbbd0165942e96d4a894fd29a954267fd0 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Mon, 13 Sep 2021 09:22:49 -0400 Subject: [PATCH 27/37] Use global NUM_CLASSES in test --- tests/wrappers/test_multioutput.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index 25bd8214767..dcb37b8776e 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -8,7 +8,7 @@ from sklearn.metrics import r2_score as sk_r2score from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester +from tests.helpers.testers import BATCH_SIZE, MetricTester, NUM_BATCHES, NUM_CLASSES from torchmetrics import Metric from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score @@ -68,7 +68,6 @@ def reset(self) -> None: self.metric.reset() -num_classes = 3 num_targets = 2 Input = namedtuple("Input", ["preds", "target"]) @@ -78,8 +77,8 @@ def reset(self) -> None: target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, num_targets), - target=torch.randint(num_classes, (NUM_BATCHES, BATCH_SIZE, num_targets)), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets), + target=torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE, num_targets)), ) @@ -118,7 +117,7 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): _multi_target_classification_inputs.preds, _multi_target_classification_inputs.target, num_targets, - dict(num_classes=num_classes), + dict(num_classes=NUM_CLASSES), ), ], ) From 7de3a6f65bc32bfca246fe08053f1126b9fe6405 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Mon, 13 Sep 2021 09:25:04 -0400 Subject: [PATCH 28/37] Address my nitpick comments on PR --- tests/wrappers/test_multioutput.py | 5 ++--- torchmetrics/wrappers/__init__.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index dcb37b8776e..ffa4b728b9f 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -18,7 +18,7 @@ class _MultioutputMetric(Metric): - """Multi-output version of the R2 score class for testing.""" + """Test class that allows passing base metric as a class rather than its instantiation to the wrapper.""" def __init__( self, @@ -73,8 +73,7 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), - target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets), diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index 5bca8460c89..e008f53124f 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 -from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 +from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 From 3cabb9b11a441ba3d905e27d886ff03ffe184c8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Sep 2021 13:26:12 +0000 Subject: [PATCH 29/37] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/wrappers/test_multioutput.py | 5 +++-- torchmetrics/wrappers/__init__.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/wrappers/test_multioutput.py b/tests/wrappers/test_multioutput.py index ffa4b728b9f..35ab90af092 100644 --- a/tests/wrappers/test_multioutput.py +++ b/tests/wrappers/test_multioutput.py @@ -8,7 +8,7 @@ from sklearn.metrics import r2_score as sk_r2score from tests.helpers import seed_all -from tests.helpers.testers import BATCH_SIZE, MetricTester, NUM_BATCHES, NUM_CLASSES +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester from torchmetrics import Metric from torchmetrics.classification import Accuracy from torchmetrics.regression import R2Score @@ -73,7 +73,8 @@ def reset(self) -> None: Input = namedtuple("Input", ["preds", "target"]) _multi_target_regression_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), ) _multi_target_classification_inputs = Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, num_targets), diff --git a/torchmetrics/wrappers/__init__.py b/torchmetrics/wrappers/__init__.py index e008f53124f..5bca8460c89 100644 --- a/torchmetrics/wrappers/__init__.py +++ b/torchmetrics/wrappers/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401 -from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 from torchmetrics.wrappers.multioutput import MultioutputWrapper # noqa: F401 +from torchmetrics.wrappers.tracker import MetricTracker # noqa: F401 From 7c89fedff09861d67797f91619cc7a51abd197c7 Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Tue, 21 Sep 2021 10:06:39 -0400 Subject: [PATCH 30/37] Remove movedim call for compatibility --- torchmetrics/wrappers/multioutput.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index ba98ddc850b..ea2c05f4d5b 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -8,14 +8,14 @@ from torchmetrics.utilities import apply_to_collection -def _get_nan_indices(*tensors: torch.Tensor, aligned_dim: int = 0) -> torch.Tensor: - """Get indices of rows along `aligned_dim` which have NaN values.""" +def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor: + """Get indices of rows along dim 0 which have NaN values.""" if len(tensors) == 0: raise ValueError("Must pass at least one tensor as argument") sentinel = tensors[0] nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool) for tensor in tensors: - permuted_tensor = tensor.movedim(aligned_dim, 0).flatten(start_dim=1) + permuted_tensor = tensor.flatten(start_dim=1) nan_idxs |= torch.any(permuted_tensor.isnan(), dim=1) return nan_idxs From c5b0a3a69ba42479102900beca6f488ae8ec394a Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Tue, 21 Sep 2021 11:23:45 -0400 Subject: [PATCH 31/37] Fix additional isnan incompat error --- torchmetrics/wrappers/multioutput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index ea2c05f4d5b..4d9d1b9d325 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -16,7 +16,7 @@ def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor: nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool) for tensor in tensors: permuted_tensor = tensor.flatten(start_dim=1) - nan_idxs |= torch.any(permuted_tensor.isnan(), dim=1) + nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1) return nan_idxs From 49bdda9a162cef11c3f69890d54797825dbe113c Mon Sep 17 00:00:00 2001 From: Stephen Malina Date: Tue, 21 Sep 2021 11:58:25 -0400 Subject: [PATCH 32/37] Fix docs formatting --- torchmetrics/wrappers/multioutput.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 4d9d1b9d325..1e4247da2ed 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -66,7 +66,7 @@ class MultioutputWrapper(Metric): Example: - Mimic R2Score in `multioutput`, `raw_values` mode: + >>> # Mimic R2Score in `multioutput`, `raw_values` mode: >>> import torch >>> from torchmetrics import MultioutputWrapper, R2Score >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) @@ -74,8 +74,7 @@ class MultioutputWrapper(Metric): >>> r2score = MultioutputWrapper(R2Score(), 2) >>> r2score(preds, target) [tensor(0.9654), tensor(0.9082)] - - Classification metric where prediction and label tensors have different shapes. + >>> # Classification metric where prediction and label tensors have different shapes. >>> from torchmetrics import BinnedAveragePrecision >>> target = torch.tensor([[1, 2], [2, 0], [1, 2]]) >>> preds = torch.tensor([ From b0c0782b7bfd233b016fa42bcf96a0806e300663 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 22 Sep 2021 10:39:35 +0200 Subject: [PATCH 33/37] Apply suggestions from code review --- torchmetrics/wrappers/multioutput.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 1e4247da2ed..7c17782d9ea 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -37,32 +37,32 @@ class MultioutputWrapper(Metric): (parameter `remove_nans`) on a per-output basis. When `remove_nans` is passed the wrapper will remove all rows Args: - base_metric: Metric, + base_metric: Metric being wrapped. - num_outputs: int + num_outputs: Expected dimensionality of the output dimension. This parameter is used to determine the number of distinct metrics we need to track. - output_dim: int = -1 + output_dim: Dimension on which output is expected. Note that while this provides some flexibility, the output dimension must be the same for all inputs to update. This applies even for metrics such as `Accuracy` where the labels can have a different number of dimensions than the predictions. This can be worked around if the output dimension can be set to -1 for both, even if -1 corresponds to different dimensions in different inputs. - remove_nans: bool = True + remove_nans: Whether to remove the intersection of rows containing NaNs from the values passed through to each underlying metric. Proper operation requires all tensors passed to update to have dimension `(N, ...)` where N represents the length of the batch or dataset being passed in. - squeeze_outputs: bool = True + squeeze_outputs: If true, will squeeze the 1-item dimensions left after `index_select` is applied. This is sometimes unnecessary but harmless for metrics such as `R2Score` but useful for certain classification metrics that can't handle additional 1-item dimensions. - compute_on_step: bool = True + compute_on_step: Whether to recompute the metric value on each update step. - dist_sync_on_step: bool = False - Required for distributed training support. See torchmetrics docs for additional details. - process_group: Optional[Any] - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: Callable = None, - Required for distributed training support. See torchmetrics docs for additional details. + dist_sync_on_step: + Required for distributed training support. + process_group: + Specify the process group on which synchronization is called. The default: None (which selects the entire world) + dist_sync_fn: + Required for distributed training support. Example: From 02309230d20df61e33780b922e675522b0ab872d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 22 Sep 2021 10:41:01 +0200 Subject: [PATCH 34/37] Apply suggestions from code review --- torchmetrics/wrappers/multioutput.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 7c17782d9ea..0af0a3dcfdc 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -25,8 +25,7 @@ class MultioutputWrapper(Metric): Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. - Unlike specific torchmetric metrics, it doesn't support any - aggregation across outputs. This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension + Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs. This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension (2, ...) where ... represents the dimensions the metric returns when not wrapped. In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude From 462447042a46c246dd5d9b26142c5f99c08a9e8f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 22 Sep 2021 10:42:39 +0200 Subject: [PATCH 35/37] Apply suggestions from code review --- torchmetrics/wrappers/multioutput.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 0af0a3dcfdc..b11969310b5 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -29,11 +29,11 @@ class MultioutputWrapper(Metric): (2, ...) where ... represents the dimensions the metric returns when not wrapped. In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude - fashion, dealing with missing labels (or other data). When `remove_nans` is passed, the class will remove the + fashion, dealing with missing labels (or other data). When ``remove_nans`` is passed, the class will remove the intersection of NaN containing "rows" upon each update for each output. For example, suppose a user uses `MultioutputWrapper` to wrap :class:`torchmetrics.regression.r2.R2Score` with 2 outputs, one of which occasionally - has missing labels for classes like `R2Score` is that this class supports removing NaN values - (parameter `remove_nans`) on a per-output basis. When `remove_nans` is passed the wrapper will remove all rows + has missing labels for classes like ``R2Score`` is that this class supports removing NaN values + (parameter ``remove_nans``) on a per-output basis. When ``remove_nans`` is passed the wrapper will remove all rows Args: base_metric: From f0f1851463989e3f27aa88cbe58851abb612bf0a Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 22 Sep 2021 17:18:21 +0200 Subject: [PATCH 36/37] long --- torchmetrics/wrappers/multioutput.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index b11969310b5..98e51223485 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -25,7 +25,8 @@ class MultioutputWrapper(Metric): Several torchmetrics metrics, such as :class:`torchmetrics.regression.spearman.SpearmanCorrcoef` lack support for multioutput mode. This class wraps such metrics to support computing one metric per output. - Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs. This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension + Unlike specific torchmetric metrics, it doesn't support any aggregation across outputs. + This means if you set `num_outputs` to 2, `compute()` will return a Tensor of dimension (2, ...) where ... represents the dimensions the metric returns when not wrapped. In addition to enabling multioutput support for metrics that lack it, this class also supports, albeit in a crude @@ -59,7 +60,8 @@ class MultioutputWrapper(Metric): dist_sync_on_step: Required for distributed training support. process_group: - Specify the process group on which synchronization is called. The default: None (which selects the entire world) + Specify the process group on which synchronization is called. + The default: None (which selects the entire world) dist_sync_fn: Required for distributed training support. From ea45a8f060362065703e977cae09513a23a8ec7e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 24 Sep 2021 12:16:19 +0200 Subject: [PATCH 37/37] fix device --- torchmetrics/wrappers/multioutput.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/wrappers/multioutput.py b/torchmetrics/wrappers/multioutput.py index 98e51223485..de0b01b642c 100644 --- a/torchmetrics/wrappers/multioutput.py +++ b/torchmetrics/wrappers/multioutput.py @@ -13,7 +13,7 @@ def _get_nan_indices(*tensors: torch.Tensor) -> torch.Tensor: if len(tensors) == 0: raise ValueError("Must pass at least one tensor as argument") sentinel = tensors[0] - nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool) + nan_idxs = torch.zeros(len(sentinel), dtype=torch.bool, device=sentinel.device) for tensor in tensors: permuted_tensor = tensor.flatten(start_dim=1) nan_idxs |= torch.any(torch.isnan(permuted_tensor), dim=1)