From f45dd630483e2a0c10c79e0fa7f60bae72846755 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 24 Mar 2021 12:45:53 -0700 Subject: [PATCH 01/33] WIP: Binned PR-related metrics --- .../test_binned_precision_recall.py | 152 ++++++++++++++++++ .../classification/binned_precision_recall.py | 148 +++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 tests/classification/test_binned_precision_recall.py create mode 100644 torchmetrics/classification/binned_precision_recall.py diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py new file mode 100644 index 00000000000..619aa0e9f51 --- /dev/null +++ b/tests/classification/test_binned_precision_recall.py @@ -0,0 +1,152 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torchmetrics.functional import precision_recall_curve +from functools import partial +from typing import Tuple + +import pytest +import torch +import torch.multiprocessing as mp +from sklearn.metrics import average_precision_score as _sk_average_precision_score +from torchmetrics.classification.binned_precision_recall import ( + BinnedAveragePrecision, + BinnedRecallAtFixedPrecision, +) +from tests.classification.inputs import ( + Input, +) +from tests.helpers.testers import ( + NUM_CLASSES, + NUM_BATCHES, + BATCH_SIZE, + NUM_PROCESSES, + MetricTester, +) + + +torch.manual_seed(42) + + +def construct_not_terrible_input(): + correct_targets = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) + targets = torch.zeros_like(preds, dtype=torch.long) + for i in range(preds.shape[0]): + for j in range(preds.shape[1]): + targets[i, j, correct_targets[i, j]] = 1 + preds += torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) * targets / 3 + + preds = preds / preds.sum(dim=2, keepdim=True) + + return Input(preds=preds, target=targets) + + +__test_input = construct_not_terrible_input() + + +def recall_at_precision_x_multilabel( + precision: torch.Tensor, recall, thresholds: torch.Tensor, min_precision: float +) -> Tuple[float, float]: + try: + max_recall, max_precision, best_threshold = max( + (r, p, t) + for p, r, t in zip(precision, recall, thresholds) + if p >= min_precision + ) + except ValueError: + max_recall, best_threshold = 0, 1e6 + + return max_recall, best_threshold + + +def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision): + max_recalls = torch.zeros(num_classes) + best_thresholds = torch.zeros(num_classes) + + for i in range(num_classes): + precisions, recalls, thresholds = precision_recall_curve( + predictions[:, i], targets[:, i], pos_label=1 + ) + max_recalls[i], best_thresholds[i] = recall_at_precision_x_multilabel( + precisions, recalls, thresholds, min_precision + ) + return max_recalls, best_thresholds + + +def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): + return _sk_average_precision_score(targets, predictions, 
average=None) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [ + ( + __test_input.preds, + __test_input.target, + _multiclass_prob_sk_metric, + NUM_CLASSES, + ), + ], +) +class TestBinnedRecallAtPrecision(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): + self.atol = 0.05 # up to second decimal using 500 thresholds + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinnedRecallAtFixedPrecision, + sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=0.6), + dist_sync_on_step=False, + check_dist_sync_on_step=False, + check_batch=False, + metric_args={ + "num_classes": num_classes, + "min_precision": 0.6, + "num_thresholds": 2000, + }, + ) + + +@pytest.mark.parametrize( + "preds, target, sk_metric, num_classes", + [ + ( + __test_input.preds, + __test_input.target, + _multiclass_average_precision_sk_metric, + NUM_CLASSES, + ), + ], +) +class TestBinnedAveragePrecision(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): + self.atol = 0.01 # up to second decimal using 200 thresholds + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinnedAveragePrecision, + sk_metric=partial(sk_metric, num_classes=num_classes), + dist_sync_on_step=False, + check_dist_sync_on_step=False, + check_batch=False, + metric_args={ + "num_classes": num_classes, + "num_thresholds": 200, + }, + ) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py new file mode 100644 index 00000000000..606dc7bce8d --- /dev/null +++ b/torchmetrics/classification/binned_precision_recall.py @@ -0,0 +1,148 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union, List + +import torch +from torchmetrics.metric import Metric + +from torchmetrics.utilities.data import METRIC_EPS, to_onehot + + +# From Lightning's AveragePrecision metric +def _average_precision_compute( + precision: torch.Tensor, + recall: torch.Tensor, + num_classes: int, +) -> Union[List[torch.Tensor], torch.Tensor]: + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + if num_classes == 1: + recall = recall[0, :] + precision = precision[0, :] + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + res = [] + for p, r in zip(precision, recall): + res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) + return res + + +class BinnedPrecisionRecallCurve(Metric): + """Returns a tensor of recalls for a fixed precision threshold. + It is a tensor instead of a single number, because it applies to multi-label inputs. 
+ """ + + TPs: torch.Tensor + FPs: torch.Tensor + FNs: torch.Tensor + thresholds: torch.Tensor + + def __init__( + self, + num_classes: int, + num_thresholds: int = 100, + compute_on_step: bool = False, # will ignore this + **kwargs + ): + # TODO: enable assert after changing testing code in Lightning + # assert not compute_on_step, "computation on each step is not supported" + super().__init__(compute_on_step=False, **kwargs) + self.num_classes = num_classes + self.num_thresholds = num_thresholds + thresholds = torch.arange(num_thresholds) / num_thresholds + self.register_buffer("thresholds", thresholds) + + for name in ("TPs", "FPs", "FNs"): + self.add_state( + name=name, + default=torch.zeros(num_classes, num_thresholds, dtype=torch.long), + dist_reduce_fx="sum", + ) + + def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: + """ + Args + preds: (n_samples, n_classes) tensor + targets: (n_samples, n_classes) tensor + """ + # binary case + if len(preds.shape) == len(targets.shape) == 1: + preds = preds.reshape(-1, 1) + targets = targets.reshape(-1, 1) + + if len(preds.shape) == len(targets.shape) + 1: + targets = to_onehot(targets, num_classes=self.num_classes) + + targets = targets == 1 + # Iterate one threshold at a time to conserve memory + for i in range(self.num_thresholds): + predictions = preds >= self.thresholds[i] + self.TPs[:, i] += (targets & predictions).sum(dim=0) + self.FPs[:, i] += ((~targets) & (predictions)).sum(dim=0) + self.FNs[:, i] += ((targets) & (~predictions)).sum(dim=0) + + def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Returns float tensor of size n_classes""" + precisions = self.TPs / (self.TPs + self.FPs + METRIC_EPS) + recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) + return (precisions, recalls, self.thresholds) + + +class BinnedAveragePrecision(BinnedPrecisionRecallCurve): + def compute(self) -> Union[List[torch.Tensor], torch.Tensor]: + precisions, recalls, thresholds = super().compute() + return _average_precision_compute(precisions, recalls, self.num_classes) + + +class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve): + def __init__( + self, + num_classes: int, + min_precision: float, + num_thresholds: int = 100, + compute_on_step: bool = False, # will ignore this + **kwargs + ): + # TODO: enable once https://github.com/PyTorchLightning/metrics/pull/122 lands + # assert not compute_on_step, "computation on each step is not supported" + super().__init__( + num_classes=num_classes, + num_thresholds=num_thresholds, + compute_on_step=compute_on_step, + **kwargs + ) + self.min_precision = min_precision + + def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: + """Returns float tensor of size n_classes""" + precisions, recalls, thresholds = super().compute() + + thresholds = thresholds.repeat(self.num_classes, 1) + condition = precisions >= self.min_precision + recalls_at_p = ( + torch.where( + condition, recalls, torch.scalar_tensor(0.0, device=condition.device) + ) + .max(dim=1) + .values + ) + thresholds_at_p = ( + torch.where( + condition, thresholds, torch.scalar_tensor(1e6, device=condition.device) + ) + .min(dim=1) + .values + ) + return (recalls_at_p, thresholds_at_p) From 4bb887df738a5fb5a9c4c9e0a3fced25bc545541 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 25 Mar 2021 13:51:46 -0700 Subject: [PATCH 02/33] attempt to fix types --- torchmetrics/classification/binned_precision_recall.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 606dc7bce8d..0b670b29a3c 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -67,7 +67,7 @@ def __init__( for name in ("TPs", "FPs", "FNs"): self.add_state( name=name, - default=torch.zeros(num_classes, num_thresholds, dtype=torch.long), + default=torch.zeros(num_classes, num_thresholds), dist_reduce_fx="sum", ) @@ -140,7 +140,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: ) thresholds_at_p = ( torch.where( - condition, thresholds, torch.scalar_tensor(1e6, device=condition.device) + condition, thresholds, torch.scalar_tensor(1e6, device=condition.device, dtype=thresholds.dtype) ) .min(dim=1) .values From c3a4174e40c059b53298d66d713380bffbc2059d Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 25 Mar 2021 18:03:24 -0700 Subject: [PATCH 03/33] switch to linspace to make old pytorch happy --- torchmetrics/classification/binned_precision_recall.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 0b670b29a3c..3a15ab9b93b 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -61,7 +61,7 @@ def __init__( super().__init__(compute_on_step=False, **kwargs) self.num_classes = num_classes self.num_thresholds = num_thresholds - thresholds = torch.arange(num_thresholds) / num_thresholds + thresholds = torch.linspace(0, 1, num_thresholds) self.register_buffer("thresholds", thresholds) for name in ("TPs", "FPs", "FNs"): From df125ef925a4cb3003b18b0142a30974fbd80b7f Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 25 Mar 2021 18:05:53 -0700 Subject: [PATCH 04/33] make flake happy --- tests/classification/test_binned_precision_recall.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 619aa0e9f51..5e75d6983cb 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -18,7 +18,6 @@ import pytest import torch -import torch.multiprocessing as mp from sklearn.metrics import average_precision_score as _sk_average_precision_score from torchmetrics.classification.binned_precision_recall import ( BinnedAveragePrecision, @@ -31,7 +30,6 @@ NUM_CLASSES, NUM_BATCHES, BATCH_SIZE, - NUM_PROCESSES, MetricTester, ) From 6ac4b347ea386b82c802ba67b26f98793ac6b281 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Mon, 29 Mar 2021 11:04:34 -0700 Subject: [PATCH 05/33] clean up --- .../test_binned_precision_recall.py | 18 +++-------- .../classification/binned_precision_recall.py | 32 +++---------------- .../classification/average_precision.py | 8 +++++ 3 files changed, 17 insertions(+), 41 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 5e75d6983cb..9fed5fb482a 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -12,27 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from torchmetrics.functional import precision_recall_curve from functools import partial from typing import Tuple import pytest import torch from sklearn.metrics import average_precision_score as _sk_average_precision_score -from torchmetrics.classification.binned_precision_recall import ( - BinnedAveragePrecision, - BinnedRecallAtFixedPrecision, -) -from tests.classification.inputs import ( - Input, -) -from tests.helpers.testers import ( - NUM_CLASSES, - NUM_BATCHES, - BATCH_SIZE, - MetricTester, -) +from tests.classification.inputs import Input +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision +from torchmetrics.functional import precision_recall_curve torch.manual_seed(42) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 3a15ab9b93b..0be485a900d 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -11,34 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union, List +from typing import List, Tuple, Union import torch -from torchmetrics.metric import Metric +from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall +from torchmetrics.metric import Metric from torchmetrics.utilities.data import METRIC_EPS, to_onehot -# From Lightning's AveragePrecision metric -def _average_precision_compute( - precision: torch.Tensor, - recall: torch.Tensor, - num_classes: int, -) -> Union[List[torch.Tensor], torch.Tensor]: - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - if num_classes == 1: - recall = recall[0, :] - precision = precision[0, :] - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - res = [] - for p, r in zip(precision, recall): - res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) - return res - - class BinnedPrecisionRecallCurve(Metric): """Returns a tensor of recalls for a fixed precision threshold. It is a tensor instead of a single number, because it applies to multi-label inputs. 
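
For reference, the step-function integral that this patch relocates into `_average_precision_compute_with_precision_recall` is just a Riemann sum over the binned curve. A minimal standalone sketch (the tensor values below are made up for illustration; the formula itself is the one from the function being moved):

    import torch

    # Precision/recall pairs ordered from low to high threshold; by
    # construction the final precision is 1.0 and the final recall 0.0.
    precision = torch.tensor([0.50, 0.66, 1.00, 1.00])
    recall = torch.tensor([1.00, 0.66, 0.33, 0.00])

    # Step-function integral of the PR curve: recall deltas weighted by
    # precision. Recall is non-increasing, hence the leading minus sign.
    average_precision = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
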
@@ -56,8 +37,7 @@ def __init__(
         compute_on_step: bool = False,  # will ignore this
         **kwargs
     ):
-        # TODO: enable assert after changing testing code in Lightning
-        # assert not compute_on_step, "computation on each step is not supported"
+        assert not compute_on_step, "computation on each step is not supported"
         super().__init__(compute_on_step=False, **kwargs)
         self.num_classes = num_classes
         self.num_thresholds = num_thresholds
@@ -103,7 +83,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
     def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
         precisions, recalls, thresholds = super().compute()
-        return _average_precision_compute(precisions, recalls, self.num_classes)
+        return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
 
 
 class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
@@ -115,8 +95,6 @@ def __init__(
         compute_on_step: bool = False,  # will ignore this
         **kwargs
     ):
-        # TODO: enable once https://github.com/PyTorchLightning/metrics/pull/122 lands
-        # assert not compute_on_step, "computation on each step is not supported"
         super().__init__(
             num_classes=num_classes,
             num_thresholds=num_thresholds,
diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py
index 6f3b9328d16..1f3bd156af2 100644
--- a/torchmetrics/functional/classification/average_precision.py
+++ b/torchmetrics/functional/classification/average_precision.py
@@ -40,6 +40,14 @@ def _average_precision_compute(
 ) -> Union[List[Tensor], Tensor]:
     # todo: `sample_weights` is unused
     precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
+    return _average_precision_compute_with_precision_recall(precision, recall, num_classes)
+
+
+def _average_precision_compute_with_precision_recall(
+    precision: Tensor,
+    recall: Tensor,
+    num_classes: int,
+) -> Union[List[Tensor], Tensor]:
     # Return the step function integral
     # The following works because the last entry of precision is
     # guaranteed to be 1, as returned by precision_recall_curve

From 8205ee32f42710a7e18dd3e4bb80785fe9089a5e Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 13:42:21 -0700
Subject: [PATCH 06/33] Add more testing, move test input generation to the
 appropriate place

---
 tests/classification/inputs.py                | 26 ++++
 .../test_binned_precision_recall.py           | 86 +++++++++++--------
 .../classification/binned_precision_recall.py | 44 +++++++---
 3 files changed, 106 insertions(+), 50 deletions(-)

diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py
index 8376f82b94e..4c95f60e952 100644
--- a/tests/classification/inputs.py
+++ b/tests/classification/inputs.py
@@ -77,3 +77,29 @@
     preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
     target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
 )
+
+
+# Generate plausible-looking inputs
+def generate_plausible_inputs_multilabel():
+    correct_targets = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
+    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
+    targets = torch.zeros_like(preds, dtype=torch.long)
+    for i in range(preds.shape[0]):
+        for j in range(preds.shape[1]):
+            targets[i, j, correct_targets[i, j]] = 1
+    preds += torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) * targets / 3
+
+    preds = preds / preds.sum(dim=2, keepdim=True)
+
+    return Input(preds=preds, target=targets)
+
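
The multilabel generator above mixes a one-hot bonus into uniform noise and renormalizes each row, so predictions favor the correct class without being trivially separable. A quick sanity check of its invariants (illustrative only, not part of the committed file; the constants come from tests.helpers.testers, already imported at the top of this module):

    import torch
    from tests.classification.inputs import generate_plausible_inputs_multilabel
    from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES

    inp = generate_plausible_inputs_multilabel()
    assert inp.preds.shape == (NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
    # Each row is renormalized into a valid probability distribution.
    assert torch.allclose(inp.preds.sum(dim=2), torch.ones(NUM_BATCHES, BATCH_SIZE))
    # Targets are one-hot: exactly one positive class per sample.
    assert (inp.target.sum(dim=2) == 1).all()
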
+ +def generate_plausible_inputs_binary(): + targets = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) + preds = torch.rand(NUM_BATCHES, BATCH_SIZE) + torch.rand(NUM_BATCHES, BATCH_SIZE) * targets / 3 + return Input(preds=preds, target=targets) + + +_input_multilabel_prob_plausible = generate_plausible_inputs_multilabel() + +_input_binary_prob_plausible = generate_plausible_inputs_binary() diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 9fed5fb482a..14be264d875 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -19,34 +19,27 @@ import torch from sklearn.metrics import average_precision_score as _sk_average_precision_score -from tests.classification.inputs import Input -from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from tests.classification.inputs import ( + _input_binary_prob, + _input_binary_prob_plausible, + _input_multilabel_prob, + _input_multilabel_prob_plausible, +) +from tests.helpers import seed_all +from tests.helpers.testers import NUM_CLASSES, MetricTester from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision from torchmetrics.functional import precision_recall_curve -torch.manual_seed(42) - - -def construct_not_terrible_input(): - correct_targets = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) - targets = torch.zeros_like(preds, dtype=torch.long) - for i in range(preds.shape[0]): - for j in range(preds.shape[1]): - targets[i, j, correct_targets[i, j]] = 1 - preds += torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) * targets / 3 - - preds = preds / preds.sum(dim=2, keepdim=True) - - return Input(preds=preds, target=targets) - - -__test_input = construct_not_terrible_input() +seed_all(42) def recall_at_precision_x_multilabel( - precision: torch.Tensor, recall, thresholds: torch.Tensor, min_precision: float + predictions: torch.Tensor, targets: torch.Tensor, min_precision: float ) -> Tuple[float, float]: + precision, recall, thresholds = precision_recall_curve( + predictions, targets, pos_label=1 + ) + try: max_recall, max_precision, best_threshold = max( (r, p, t) @@ -64,15 +57,18 @@ def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision) best_thresholds = torch.zeros(num_classes) for i in range(num_classes): - precisions, recalls, thresholds = precision_recall_curve( - predictions[:, i], targets[:, i], pos_label=1 - ) max_recalls[i], best_thresholds[i] = recall_at_precision_x_multilabel( - precisions, recalls, thresholds, min_precision + predictions[:, i], targets[:, i], min_precision ) return max_recalls, best_thresholds +def _binary_prob_sk_metric(predictions, targets, num_classes, min_precision): + return recall_at_precision_x_multilabel( + predictions, targets, min_precision + ) + + def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): return _sk_average_precision_score(targets, predictions, average=None) @@ -80,9 +76,17 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): @pytest.mark.parametrize( "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _binary_prob_sk_metric, 1), + (_input_binary_prob_plausible.preds, _input_binary_prob_plausible.target, _binary_prob_sk_metric, 1), + ( + 
_input_multilabel_prob_plausible.preds, + _input_multilabel_prob_plausible.target, + _multiclass_prob_sk_metric, + NUM_CLASSES, + ), ( - __test_input.preds, - __test_input.target, + _input_multilabel_prob.preds, + _input_multilabel_prob.target, _multiclass_prob_sk_metric, NUM_CLASSES, ), @@ -90,20 +94,21 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ) class TestBinnedRecallAtPrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): - self.atol = 0.05 # up to second decimal using 500 thresholds + @pytest.mark.parametrize("min_precision", [0.1, 0.3, 0.5]) + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precision): + self.atol = 0.05 # Binned and SKLearn implementations can produce different values self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=BinnedRecallAtFixedPrecision, - sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=0.6), + sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=min_precision), dist_sync_on_step=False, check_dist_sync_on_step=False, check_batch=False, metric_args={ "num_classes": num_classes, - "min_precision": 0.6, + "min_precision": min_precision, "num_thresholds": 2000, }, ) @@ -112,9 +117,17 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): @pytest.mark.parametrize( "preds, target, sk_metric, num_classes", [ + (_input_binary_prob.preds, _input_binary_prob.target, _multiclass_average_precision_sk_metric, 1), + (_input_binary_prob_plausible.preds, _input_binary_prob_plausible.target, _multiclass_average_precision_sk_metric, 1), + ( + _input_multilabel_prob_plausible.preds, + _input_multilabel_prob_plausible.target, + _multiclass_average_precision_sk_metric, + NUM_CLASSES, + ), ( - __test_input.preds, - __test_input.target, + _input_multilabel_prob.preds, + _input_multilabel_prob.target, _multiclass_average_precision_sk_metric, NUM_CLASSES, ), @@ -122,8 +135,9 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): ) class TestBinnedAveragePrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): - self.atol = 0.01 # up to second decimal using 200 thresholds + @pytest.mark.parametrize("num_thresholds", [200, 300]) + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, num_thresholds): + self.atol = 0.01 self.run_class_metric_test( ddp=ddp, preds=preds, @@ -135,6 +149,6 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp): check_batch=False, metric_args={ "num_classes": num_classes, - "num_thresholds": 200, + "num_thresholds": num_thresholds, }, ) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 0be485a900d..1373c1fc063 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -41,7 +41,7 @@ def __init__( super().__init__(compute_on_step=False, **kwargs) self.num_classes = num_classes self.num_thresholds = num_thresholds - thresholds = torch.linspace(0, 1, num_thresholds) + thresholds = torch.linspace(0, 1.0 + METRIC_EPS, num_thresholds) self.register_buffer("thresholds", thresholds) for name in ("TPs", "FPs", "FNs"): @@ -77,12 +77,20 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Returns float tensor of size 
n_classes""" precisions = self.TPs / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) - return (precisions, recalls, self.thresholds) + # Need to guarantee that last precision=1 and recall=0 + precisions = torch.cat([precisions, torch.ones(self.num_classes, 1, + dtype=precisions.dtype, device=precisions.device)], dim=1) + recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, + dtype=recalls.dtype, device=recalls.device)], dim=1) + if self.num_classes == 1: + return (precisions[0, :], recalls[0, :], self.thresholds) + else: + return (precisions, recalls, self.thresholds) class BinnedAveragePrecision(BinnedPrecisionRecallCurve): def compute(self) -> Union[List[torch.Tensor], torch.Tensor]: - precisions, recalls, thresholds = super().compute() + precisions, recalls, _ = super().compute() return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes) @@ -106,21 +114,29 @@ def __init__( def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: """Returns float tensor of size n_classes""" precisions, recalls, thresholds = super().compute() - - thresholds = thresholds.repeat(self.num_classes, 1) condition = precisions >= self.min_precision - recalls_at_p = ( + + if self.num_classes == 1: + recall_at_p, index = ( + torch.where( + condition, recalls, torch.scalar_tensor(0.0, device=condition.device) + ) + .max(dim=0) + ) + return recall_at_p, self.thresholds[index] + + recalls_at_p, indices = ( torch.where( condition, recalls, torch.scalar_tensor(0.0, device=condition.device) ) .max(dim=1) - .values - ) - thresholds_at_p = ( - torch.where( - condition, thresholds, torch.scalar_tensor(1e6, device=condition.device, dtype=thresholds.dtype) - ) - .min(dim=1) - .values ) + + thresholds_at_p = torch.zeros_like(recalls_at_p, device=condition.device, dtype=thresholds.dtype) + for i in range(self.num_classes): + if recalls_at_p[i] == 0.0: + thresholds_at_p[i] = 1e6 + else: + thresholds_at_p[i] = self.thresholds[indices[i]] + return (recalls_at_p, thresholds_at_p) From eb70c497a131ff9da751f9a34b439894ba6b74eb Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 7 Apr 2021 14:29:09 -0700 Subject: [PATCH 07/33] bugfixes and more stable and thorough tests --- .../test_binned_precision_recall.py | 17 ++++++++++------- .../classification/binned_precision_recall.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 14be264d875..8eb4f3af503 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -15,9 +15,11 @@ from functools import partial from typing import Tuple +import numpy as np import pytest import torch from sklearn.metrics import average_precision_score as _sk_average_precision_score +from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve from tests.classification.inputs import ( _input_binary_prob, @@ -28,7 +30,6 @@ from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision -from torchmetrics.functional import precision_recall_curve seed_all(42) @@ -36,8 +37,8 @@ def recall_at_precision_x_multilabel( predictions: torch.Tensor, targets: torch.Tensor, min_precision: float ) -> Tuple[float, float]: - precision, recall, thresholds = 
precision_recall_curve( - predictions, targets, pos_label=1 + precision, recall, thresholds = _sk_precision_recall_curve( + targets, predictions, ) try: @@ -49,7 +50,7 @@ def recall_at_precision_x_multilabel( except ValueError: max_recall, best_threshold = 0, 1e6 - return max_recall, best_threshold + return float(max_recall), float(best_threshold) def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision): @@ -94,9 +95,11 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ) class TestBinnedRecallAtPrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("min_precision", [0.1, 0.3, 0.5]) + @pytest.mark.parametrize("min_precision", [0.1, 0.3, 0.5, 0.8]) def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precision): - self.atol = 0.05 # Binned and SKLearn implementations can produce different values + self.atol = 0.01 + # rounding will simulate binning for both implementations + preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 self.run_class_metric_test( ddp=ddp, preds=preds, @@ -109,7 +112,7 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precisi metric_args={ "num_classes": num_classes, "min_precision": min_precision, - "num_thresholds": 2000, + "num_thresholds": 101, }, ) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 1373c1fc063..b4dd9a4e1db 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -75,17 +75,18 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Returns float tensor of size n_classes""" - precisions = self.TPs / (self.TPs + self.FPs + METRIC_EPS) + precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) # Need to guarantee that last precision=1 and recall=0 precisions = torch.cat([precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)], dim=1) recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)], dim=1) + thresholds = torch.cat([self.thresholds, torch.ones(1, dtype=recalls.dtype, device=recalls.device)], dim=0) if self.num_classes == 1: - return (precisions[0, :], recalls[0, :], self.thresholds) + return (precisions[0, :], recalls[0, :], thresholds) else: - return (precisions, recalls, self.thresholds) + return (precisions, recalls, thresholds) class BinnedAveragePrecision(BinnedPrecisionRecallCurve): @@ -123,7 +124,10 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: ) .max(dim=0) ) - return recall_at_p, self.thresholds[index] + if recall_at_p == 0.0: + return recall_at_p, torch.scalar_tensor(1e6, device=condition.device) + else: + return recall_at_p, thresholds[index] recalls_at_p, indices = ( torch.where( @@ -137,6 +141,6 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: if recalls_at_p[i] == 0.0: thresholds_at_p[i] = 1e6 else: - thresholds_at_p[i] = self.thresholds[indices[i]] + thresholds_at_p[i] = thresholds[indices[i]] return (recalls_at_p, thresholds_at_p) From 15c07f2cad8d273d79acf6806e8773d3db1d4b41 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 7 Apr 2021 14:30:57 -0700 Subject: [PATCH 08/33] flake8 --- tests/classification/test_binned_precision_recall.py | 7 ++++++- 1 
file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 8eb4f3af503..cd2216b49e8 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -121,7 +121,12 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precisi "preds, target, sk_metric, num_classes", [ (_input_binary_prob.preds, _input_binary_prob.target, _multiclass_average_precision_sk_metric, 1), - (_input_binary_prob_plausible.preds, _input_binary_prob_plausible.target, _multiclass_average_precision_sk_metric, 1), + ( + _input_binary_prob_plausible.preds, + _input_binary_prob_plausible.target, + _multiclass_average_precision_sk_metric, + 1, + ), ( _input_multilabel_prob_plausible.preds, _input_multilabel_prob_plausible.target, From e1bb5dc22432968c50fea5461445e3b86cd8278f Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 7 Apr 2021 16:03:40 -0700 Subject: [PATCH 09/33] Reuse python zip-based implementation as it can't be reproduced with torch.where/max --- .../test_binned_precision_recall.py | 11 +++-- .../classification/binned_precision_recall.py | 49 +++++++++---------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index cd2216b49e8..8ea8caf24be 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -95,9 +95,9 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ) class TestBinnedRecallAtPrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("min_precision", [0.1, 0.3, 0.5, 0.8]) + @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precision): - self.atol = 0.01 + self.atol = 0.02 # rounding will simulate binning for both implementations preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 self.run_class_metric_test( @@ -143,9 +143,12 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precisi ) class TestBinnedAveragePrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("num_thresholds", [200, 300]) + @pytest.mark.parametrize("num_thresholds", [101, 301]) def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, num_thresholds): - self.atol = 0.01 + self.atol = 0.02 + # rounding will simulate binning for both implementations + preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 + self.run_class_metric_test( ddp=ddp, preds=preds, diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index b4dd9a4e1db..43c8b5b84c6 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -20,6 +20,24 @@ from torchmetrics.utilities.data import METRIC_EPS, to_onehot +def _recall_at_precision( + precision: torch.Tensor, recall: torch.Tensor, thresholds: torch.Tensor, min_precision: float +): + try: + max_recall, max_precision, best_threshold = max( + (r, p, t) + for p, r, t in zip(precision, recall, thresholds) + if p >= min_precision + ) + except ValueError: + max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) + + if 
max_recall == 0.0: + best_threshold = torch.tensor(1e6, device=thresholds.device, dtype=thresholds.dtype) + + return max_recall, best_threshold + + class BinnedPrecisionRecallCurve(Metric): """Returns a tensor of recalls for a fixed precision threshold. It is a tensor instead of a single number, because it applies to multi-label inputs. @@ -41,7 +59,7 @@ def __init__( super().__init__(compute_on_step=False, **kwargs) self.num_classes = num_classes self.num_thresholds = num_thresholds - thresholds = torch.linspace(0, 1.0 + METRIC_EPS, num_thresholds) + thresholds = torch.linspace(0, 1.0, num_thresholds) self.register_buffer("thresholds", thresholds) for name in ("TPs", "FPs", "FNs"): @@ -115,32 +133,13 @@ def __init__( def compute(self) -> Tuple[torch.Tensor, torch.Tensor]: """Returns float tensor of size n_classes""" precisions, recalls, thresholds = super().compute() - condition = precisions >= self.min_precision if self.num_classes == 1: - recall_at_p, index = ( - torch.where( - condition, recalls, torch.scalar_tensor(0.0, device=condition.device) - ) - .max(dim=0) - ) - if recall_at_p == 0.0: - return recall_at_p, torch.scalar_tensor(1e6, device=condition.device) - else: - return recall_at_p, thresholds[index] - - recalls_at_p, indices = ( - torch.where( - condition, recalls, torch.scalar_tensor(0.0, device=condition.device) - ) - .max(dim=1) - ) + return _recall_at_precision(precisions, recalls, thresholds, self.min_precision) - thresholds_at_p = torch.zeros_like(recalls_at_p, device=condition.device, dtype=thresholds.dtype) + recalls_at_p = torch.zeros(self.num_classes, device=recalls.device, dtype=recalls.dtype) + thresholds_at_p = torch.zeros(self.num_classes, device=thresholds.device, dtype=thresholds.dtype) for i in range(self.num_classes): - if recalls_at_p[i] == 0.0: - thresholds_at_p[i] = 1e6 - else: - thresholds_at_p[i] = thresholds[indices[i]] - + recalls_at_p[i], thresholds_at_p[i] = _recall_at_precision( + precisions[i, :], recalls[i, :], thresholds, self.min_precision) return (recalls_at_p, thresholds_at_p) From c39384afc08bc7dcad31a3c177fef339243b2ba3 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 7 Apr 2021 16:08:53 -0700 Subject: [PATCH 10/33] address comments --- tests/classification/inputs.py | 14 +++++----- .../test_binned_precision_recall.py | 28 +++++++++---------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py index 4c95f60e952..a1a222d38cd 100644 --- a/tests/classification/inputs.py +++ b/tests/classification/inputs.py @@ -80,23 +80,23 @@ # Generate plausible-looking inputs -def generate_plausible_inputs_multilabel(): - correct_targets = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) +def generate_plausible_inputs_multilabel(num_classes=NUM_CLASSES, num_batches=NUM_BATCHES, batch_size=BATCH_SIZE): + correct_targets = torch.randint(high=num_classes, size=(num_batches, batch_size)) + preds = torch.rand(num_batches, batch_size, num_classes) targets = torch.zeros_like(preds, dtype=torch.long) for i in range(preds.shape[0]): for j in range(preds.shape[1]): targets[i, j, correct_targets[i, j]] = 1 - preds += torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) * targets / 3 + preds += torch.rand(num_batches, batch_size, num_classes) * targets / 3 preds = preds / preds.sum(dim=2, keepdim=True) return Input(preds=preds, target=targets) -def generate_plausible_inputs_binary(): - targets = torch.randint(high=2, 
size=(NUM_BATCHES, BATCH_SIZE)) - preds = torch.rand(NUM_BATCHES, BATCH_SIZE) + torch.rand(NUM_BATCHES, BATCH_SIZE) * targets / 3 +def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_SIZE): + targets = torch.randint(high=2, size=(num_batches, batch_size)) + preds = torch.rand(num_batches, batch_size) + torch.rand(num_batches, batch_size) * targets / 3 return Input(preds=preds, target=targets) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 8ea8caf24be..47a47e2def3 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -23,9 +23,9 @@ from tests.classification.inputs import ( _input_binary_prob, - _input_binary_prob_plausible, - _input_multilabel_prob, - _input_multilabel_prob_plausible, + _input_binary_prob_plausible as _input_binary_prob_ok, + _input_multilabel_prob as _input_mlb_prob, + _input_multilabel_prob_plausible as _input_mlb_prob_ok, ) from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester @@ -78,16 +78,16 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): "preds, target, sk_metric, num_classes", [ (_input_binary_prob.preds, _input_binary_prob.target, _binary_prob_sk_metric, 1), - (_input_binary_prob_plausible.preds, _input_binary_prob_plausible.target, _binary_prob_sk_metric, 1), + (_input_binary_prob_ok.preds, _input_binary_prob_ok.target, _binary_prob_sk_metric, 1), ( - _input_multilabel_prob_plausible.preds, - _input_multilabel_prob_plausible.target, + _input_mlb_prob_ok.preds, + _input_mlb_prob_ok.target, _multiclass_prob_sk_metric, NUM_CLASSES, ), ( - _input_multilabel_prob.preds, - _input_multilabel_prob.target, + _input_mlb_prob.preds, + _input_mlb_prob.target, _multiclass_prob_sk_metric, NUM_CLASSES, ), @@ -122,20 +122,20 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precisi [ (_input_binary_prob.preds, _input_binary_prob.target, _multiclass_average_precision_sk_metric, 1), ( - _input_binary_prob_plausible.preds, - _input_binary_prob_plausible.target, + _input_binary_prob_ok.preds, + _input_binary_prob_ok.target, _multiclass_average_precision_sk_metric, 1, ), ( - _input_multilabel_prob_plausible.preds, - _input_multilabel_prob_plausible.target, + _input_mlb_prob_ok.preds, + _input_mlb_prob_ok.target, _multiclass_average_precision_sk_metric, NUM_CLASSES, ), ( - _input_multilabel_prob.preds, - _input_multilabel_prob.target, + _input_mlb_prob.preds, + _input_mlb_prob.target, _multiclass_average_precision_sk_metric, NUM_CLASSES, ), From b6b289e0ad3058aff36dd87d70c1a8227445a8f9 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Wed, 7 Apr 2021 16:09:47 -0700 Subject: [PATCH 11/33] isort --- tests/classification/test_binned_precision_recall.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 47a47e2def3..4516a76be3f 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -21,12 +21,10 @@ from sklearn.metrics import average_precision_score as _sk_average_precision_score from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve -from tests.classification.inputs import ( - _input_binary_prob, - _input_binary_prob_plausible as _input_binary_prob_ok, - 
_input_multilabel_prob as _input_mlb_prob,
-    _input_multilabel_prob_plausible as _input_mlb_prob_ok,
-)
+from tests.classification.inputs import _input_binary_prob
+from tests.classification.inputs import _input_binary_prob_plausible as _input_binary_prob_ok
+from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
+from tests.classification.inputs import _input_multilabel_prob_plausible as _input_mlb_prob_ok
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, MetricTester
 from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision

From a3c5dd2f85f7be67c3b5ecb78be7ba289039cfaa Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 19:13:28 -0700
Subject: [PATCH 12/33] Add docs and doctests, make APIs same as non-binned
 versions

---
 docs/source/references/modules.rst            |  12 ++
 tests/classification/inputs.py                |   2 +-
 .../test_binned_precision_recall.py           |  18 +--
 tests/classification/test_precision_recall.py |   2 +-
 torchmetrics/__init__.py                      |   3 +
 torchmetrics/classification/__init__.py       |   1 +
 .../classification/binned_precision_recall.py | 146 +++++++++++++++---
 7 files changed, 155 insertions(+), 29 deletions(-)

diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index 29d58252e16..e7f892e6bb1 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -126,6 +126,12 @@ AveragePrecision
 .. autoclass:: torchmetrics.AveragePrecision
     :noindex:
 
+BinnedAveragePrecision
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: torchmetrics.BinnedAveragePrecision
+    :noindex:
+
 AUC
 ~~~
 
@@ -198,6 +204,12 @@ PrecisionRecallCurve
 .. autoclass:: torchmetrics.PrecisionRecallCurve
     :noindex:
 
+BinnedPrecisionRecallCurve
+~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. 
autoclass:: torchmetrics.BinnedPrecisionRecallCurve + :noindex: + Recall ~~~~~~ diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py index a1a222d38cd..c6ce83a4069 100644 --- a/tests/classification/inputs.py +++ b/tests/classification/inputs.py @@ -97,7 +97,7 @@ def generate_plausible_inputs_multilabel(num_classes=NUM_CLASSES, num_batches=NU def generate_plausible_inputs_binary(num_batches=NUM_BATCHES, batch_size=BATCH_SIZE): targets = torch.randint(high=2, size=(num_batches, batch_size)) preds = torch.rand(num_batches, batch_size) + torch.rand(num_batches, batch_size) * targets / 3 - return Input(preds=preds, target=targets) + return Input(preds=preds / (preds.max() + 0.01), target=targets) _input_multilabel_prob_plausible = generate_plausible_inputs_multilabel() diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 4516a76be3f..0cb68b0d3cd 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -69,7 +69,8 @@ def _binary_prob_sk_metric(predictions, targets, num_classes, min_precision): def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): - return _sk_average_precision_score(targets, predictions, average=None) + # replace nan with 0 + return torch.nan_to_num(_sk_average_precision_score(targets, predictions, average=None)) @pytest.mark.parametrize( @@ -93,20 +94,20 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ) class TestBinnedRecallAtPrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) - def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precision): + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision): self.atol = 0.02 # rounding will simulate binning for both implementations preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=BinnedRecallAtFixedPrecision, sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=min_precision), - dist_sync_on_step=False, - check_dist_sync_on_step=False, - check_batch=False, + dist_sync_on_step=dist_sync_on_step, metric_args={ "num_classes": num_classes, "min_precision": min_precision, @@ -141,8 +142,9 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precisi ) class TestBinnedAveragePrecision(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("num_thresholds", [101, 301]) - def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, num_thresholds): + def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds): self.atol = 0.02 # rounding will simulate binning for both implementations preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 @@ -153,9 +155,7 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, num_thresho target=target, metric_class=BinnedAveragePrecision, sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=False, - check_dist_sync_on_step=False, - check_batch=False, + dist_sync_on_step=dist_sync_on_step, metric_args={ "num_classes": num_classes, "num_thresholds": 
num_thresholds, diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index dedb86d5178..0b65ba8fda8 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -195,7 +195,7 @@ def test_no_support(metric_class, metric_fn): ) class TestPrecisionRecall(MetricTester): - @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("ddp", [False, True]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_precision_recall_class( self, diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index 01d7254edba..7e9fd3e5d8a 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -26,6 +26,9 @@ ROC, Accuracy, AveragePrecision, + BinnedAveragePrecision, + BinnedPrecisionRecallCurve, + BinnedRecallAtFixedPrecision, CohenKappa, ConfusionMatrix, FBeta, diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 06be458d60a..b54f42b2e5b 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -26,3 +26,4 @@ from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from torchmetrics.classification.roc import ROC # noqa: F401 from torchmetrics.classification.stat_scores import StatScores # noqa: F401 +from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve, BinnedAveragePrecision, BinnedRecallAtFixedPrecision # noqa: F401 \ No newline at end of file diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 43c8b5b84c6..2eada111a99 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch @@ -39,8 +39,72 @@ def _recall_at_precision( class BinnedPrecisionRecallCurve(Metric): - """Returns a tensor of recalls for a fixed precision threshold. - It is a tensor instead of a single number, because it applies to multi-label inputs. + """ + Computes precision-recall pairs for different thresholds. Works for both + binary and multiclass problems. In the case of multiclass, the values will + be calculated based on a one-vs-the-rest approach. + + Computation is performed in constant-memory by computing precision and recall + for ``num_thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + + Forward accepts + + - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor + with probabilities, where C is the number of classes. + + - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels + + Args: + num_classes: integer with number of classes. For binary, set to 1. + num_thresholds: number of bins used for computation. More bins will lead to more detailed + curve and accurate estimates, but will be slower and consume more memory. Default 100 + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. 
default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example (binary case): + >>> from torchmetrics import BinnedPrecisionRecallCurve + >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) + >>> target = torch.tensor([0, 1, 1, 0]) + >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=1, num_thresholds=5) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision + tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000]) + >>> recall + tensor([1.0000, 0.5000, 0.5000, 0.5000, 0.0000, 0.0000]) + >>> thresholds + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + Example (multiclass case): + >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... [0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> pr_curve = BinnedPrecisionRecallCurve(num_classes=5, num_thresholds=3) + >>> precision, recall, thresholds = pr_curve(pred, target) + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([0.2500, 1.0000, 1.0000, 1.0000]), + tensor([0.2500, 1.0000, 1.0000, 1.0000]), + tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]), + tensor([2.5000e-01, 1.0000e-06, 1.0000e+00, 1.0000e+00]), + tensor([2.5000e-07, 1.0000e+00, 1.0000e+00, 1.0000e+00])] + >>> recall # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 1.0000, 0.0000, 0.0000]), + tensor([1.0000, 1.0000, 0.0000, 0.0000]), + tensor([1.0000, 0.0000, 0.0000, 0.0000]), + tensor([1.0000, 0.0000, 0.0000, 0.0000]), + tensor([0., 0., 0., 0.])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000]), + tensor([0.0000, 0.5000, 1.0000])] """ TPs: torch.Tensor @@ -52,11 +116,16 @@ def __init__( self, num_classes: int, num_thresholds: int = 100, - compute_on_step: bool = False, # will ignore this - **kwargs + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, ): - assert not compute_on_step, "computation on each step is not supported" - super().__init__(compute_on_step=False, **kwargs) + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + self.num_classes = num_classes self.num_thresholds = num_thresholds thresholds = torch.linspace(0, 1.0, num_thresholds) @@ -65,7 +134,7 @@ def __init__( for name in ("TPs", "FPs", "FNs"): self.add_state( name=name, - default=torch.zeros(num_classes, num_thresholds), + default=torch.zeros(num_classes, num_thresholds, dtype=torch.float32), dist_reduce_fx="sum", ) @@ -95,19 +164,58 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Returns float tensor of size n_classes""" precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) - # Need to guarantee that last precision=1 and recall=0 + + # Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve precisions = torch.cat([precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device)], dim=1) recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)], dim=1) - thresholds = torch.cat([self.thresholds, torch.ones(1, dtype=recalls.dtype, device=recalls.device)], dim=0) if self.num_classes == 1: - 
return (precisions[0, :], recalls[0, :], thresholds)
+            return (precisions[0, :], recalls[0, :], self.thresholds)
         else:
-            return (precisions, recalls, thresholds)
+            return (list(precisions), list(recalls), [self.thresholds for _ in range(self.num_classes)])
 
 
 class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
+    """
+    Computes the average precision score, which summarizes the precision-recall
+    curve into one number. Works for both binary and multiclass problems.
+    In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
+
+    Forward accepts
+
+    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
+      with probabilities, where C is the number of classes.
+
+    - ``target`` (long tensor): ``(N, ...)`` with integer labels
+
+    Args:
+        num_classes: integer with number of classes. Not necessary to provide
+            for binary problems.
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Example (binary case):
+        >>> from torchmetrics import BinnedAveragePrecision
+        >>> pred = torch.tensor([0, 1, 2, 3])
+        >>> target = torch.tensor([0, 1, 1, 1])
+        >>> average_precision = BinnedAveragePrecision(num_classes=1, num_thresholds=10)
+        >>> average_precision(pred, target)
+        tensor(1.0000)
+
+    Example (multiclass case):
+        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
+        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
+        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
+        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
+        >>> target = torch.tensor([0, 1, 3, 2])
+        >>> average_precision = BinnedAveragePrecision(num_classes=5, num_thresholds=10)
+        >>> average_precision(pred, target)
+        [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
+    """
+
     def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
         precisions, recalls, _ = super().compute()
         return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
@@ -119,14 +227,16 @@ def __init__(
         num_classes: int,
         min_precision: float,
         num_thresholds: int = 100,
-        compute_on_step: bool = False,  # will ignore this
-        **kwargs
+        compute_on_step: bool = True,
+        dist_sync_on_step: bool = False,
+        process_group: Optional[Any] = None,
     ):
         super().__init__(
             num_classes=num_classes,
             num_thresholds=num_thresholds,
             compute_on_step=compute_on_step,
-            **kwargs
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
         )
         self.min_precision = min_precision
 
@@ -137,9 +247,9 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.num_classes == 1:
             return _recall_at_precision(precisions, recalls, thresholds, self.min_precision)
 
-        recalls_at_p = torch.zeros(self.num_classes, device=recalls.device, dtype=recalls.dtype)
-        thresholds_at_p = torch.zeros(self.num_classes, device=thresholds.device, dtype=thresholds.dtype)
+        recalls_at_p = torch.zeros(self.num_classes, device=recalls[0].device, dtype=recalls[0].dtype)
+        thresholds_at_p = torch.zeros(self.num_classes, device=thresholds[0].device, dtype=thresholds[0].dtype)
         for i in range(self.num_classes):
             recalls_at_p[i], thresholds_at_p[i] = _recall_at_precision(
-                precisions[i, :], recalls[i, :], thresholds, self.min_precision)
+                precisions[i], recalls[i], thresholds[i], self.min_precision)
         return (recalls_at_p, thresholds_at_p)
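Before the cleanup commits that follow, the core trick the series has landed is worth restating: instead of storing every prediction, the metric keeps running TP/FP/FN counts per fixed threshold bucket, so state stays O(num_thresholds) no matter how many batches arrive. A condensed, self-contained sketch of that accumulation for the binary case (an illustrative rewrite under stated assumptions, not the patch's exact code; `binned_pr_curve` is a made-up name):

    import torch

    def binned_pr_curve(preds: torch.Tensor, target: torch.Tensor, num_thresholds: int = 100):
        # State is O(num_thresholds), independent of how many samples are seen.
        thresholds = torch.linspace(0, 1.0, num_thresholds)
        TPs = torch.zeros(num_thresholds)
        FPs = torch.zeros(num_thresholds)
        FNs = torch.zeros(num_thresholds)
        target = target.bool()
        for i in range(num_thresholds):  # mirrors the per-threshold loop in update()
            pred_pos = preds >= thresholds[i]
            TPs[i] += (target & pred_pos).sum()
            FPs[i] += (~target & pred_pos).sum()
            FNs[i] += (target & ~pred_pos).sum()
        eps = 1e-6  # stand-in for METRIC_EPS
        precision = (TPs + eps) / (TPs + FPs + eps)
        recall = TPs / (TPs + FNs + eps)
        return precision, recall, thresholds

    precision, recall, thresholds = binned_pr_curve(
        torch.tensor([0.1, 0.8, 0.4, 0.9]), torch.tensor([0, 1, 0, 1]), num_thresholds=5
    )

The price is resolution: the exact curve uses every distinct score as a threshold, while the binned variant quantizes to the grid. That trade-off is why the test suite in this series compares against sklearn with a loose atol and, in later commits, rounds predictions to simulate binning on both sides.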
From 6e568d9f01d87d4d3e3bf99555c322cdb2bedb30 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 19:15:41 -0700
Subject: [PATCH 13/33] pep8

---
 torchmetrics/classification/__init__.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py
index b54f42b2e5b..0a300c67a93 100644
--- a/torchmetrics/classification/__init__.py
+++ b/torchmetrics/classification/__init__.py
@@ -26,4 +26,6 @@
 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve  # noqa: F401
 from torchmetrics.classification.roc import ROC  # noqa: F401
 from torchmetrics.classification.stat_scores import StatScores  # noqa: F401
-from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve, BinnedAveragePrecision, BinnedRecallAtFixedPrecision  # noqa: F401
\ No newline at end of file
+from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve  # noqa: F401
+from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision  # noqa: F401
+from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision  # noqa: F401
\ No newline at end of file

From d3a5d9f7323337470e3f6db64f9d31da7531b9f4 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 19:16:51 -0700
Subject: [PATCH 14/33] isort

---
 torchmetrics/classification/__init__.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py
index 0a300c67a93..05cbca4e4a3 100644
--- a/torchmetrics/classification/__init__.py
+++ b/torchmetrics/classification/__init__.py
@@ -15,6 +15,9 @@
 from torchmetrics.classification.auc import AUC  # noqa: F401
 from torchmetrics.classification.auroc import AUROC  # noqa: F401
 from torchmetrics.classification.average_precision import AveragePrecision  # noqa: F401
+from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision  # noqa: F401
+from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve  # noqa: F401
+from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision  # noqa: F401
 from torchmetrics.classification.cohen_kappa import CohenKappa  # noqa: F401
 from torchmetrics.classification.confusion_matrix import ConfusionMatrix  # noqa: F401
 from torchmetrics.classification.f_beta import F1, FBeta  # noqa: F401
@@ -26,6 +29,3 @@
 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve  # noqa: F401
 from torchmetrics.classification.roc import ROC  # noqa: F401
 from torchmetrics.classification.stat_scores import StatScores  # noqa: F401
-from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve  # noqa: F401
-from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision  # noqa: F401
-from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision  # noqa: F401
\ No newline at end of file

From 6d5b8b23acc1794e5ad6afac82b70a2cf039db34 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 19:23:36 -0700
Subject: [PATCH 15/33] doctests like longer title underlines :O

---
 docs/source/references/modules.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst
index e7f892e6bb1..89f60fab3fd 100644
--- a/docs/source/references/modules.rst
+++ b/docs/source/references/modules.rst
@@ -127,7 +127,7 @@ AveragePrecision
     :noindex:
 BinnedAveragePrecision
-~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: torchmetrics.BinnedAveragePrecision
     :noindex:
@@ -205,7 +205,7 @@ PrecisionRecallCurve
     :noindex:
 
 BinnedPrecisionRecallCurve
-~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autoclass:: torchmetrics.BinnedPrecisionRecallCurve
     :noindex:

From a1a729450c7e1a38ca2bee1e8cb95d72afe20473 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 19:30:33 -0700
Subject: [PATCH 16/33] use numpy's nan_to_num

---
 tests/classification/test_binned_precision_recall.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py
index 0cb68b0d3cd..3b02ea7713c 100644
--- a/tests/classification/test_binned_precision_recall.py
+++ b/tests/classification/test_binned_precision_recall.py
@@ -70,7 +70,7 @@
 def _multiclass_average_precision_sk_metric(predictions, targets, num_classes):
     # replace nan with 0
-    return torch.nan_to_num(_sk_average_precision_score(targets, predictions, average=None))
+    return np.nan_to_num(_sk_average_precision_score(targets, predictions, average=None))

From 4e276bec26dfa7b5689bdb07eeb432e354604f88 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Wed, 7 Apr 2021 20:56:11 -0700
Subject: [PATCH 17/33] add atol to bleu tests to make them more stable

---
 tests/functional/test_nlp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/functional/test_nlp.py b/tests/functional/test_nlp.py
index 76bf77041da..655db78173b 100644
--- a/tests/functional/test_nlp.py
+++ b/tests/functional/test_nlp.py
@@ -63,11 +63,11 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth):
         smoothing_function=smooth_func,
     )
     pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
-    assert torch.allclose(pl_output, tensor(nltk_output))
+    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-3)
 
     nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
     pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
-    assert torch.allclose(pl_output, tensor(nltk_output))
+    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-3)
 
 
 def test_bleu_empty():

From 9ce97452120bcae3460149579df0861659f20f10 Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Thu, 8 Apr 2021 10:08:51 -0700
Subject: [PATCH 18/33] atol=1e-2 for bleu

---
 tests/functional/test_nlp.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/functional/test_nlp.py b/tests/functional/test_nlp.py
index 655db78173b..8cfe6f738cf 100644
--- a/tests/functional/test_nlp.py
+++ b/tests/functional/test_nlp.py
@@ -63,11 +63,11 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth):
         smoothing_function=smooth_func,
     )
     pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
-    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-3)
+    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-2)
 
     nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
     pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
-    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-3)
+    assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-2)
 
 
 def test_bleu_empty():
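The docs added in the next commit describe BinnedRecallAtFixedPrecision, whose selection rule is: among the binned operating points whose precision reaches min_precision, report the largest recall and its threshold, falling back to recall 0 with a sentinel threshold of 1e6 when no point qualifies. That rule is compact enough to restate as a standalone sketch (illustrative only; `recall_at_min_precision` is a hypothetical name mirroring the `_recall_at_precision` helper seen elsewhere in this series):

    import torch

    def recall_at_min_precision(precision, recall, thresholds, min_precision):
        # Keep only operating points whose precision clears the bar,
        # then take the one with the highest recall.
        candidates = [(r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision]
        if not candidates:
            # Sentinel used when no threshold reaches min_precision.
            return torch.tensor(0.0), torch.tensor(1e6)
        max_recall, _, best_threshold = max(candidates)
        return max_recall, best_threshold

    precision = torch.tensor([0.4, 0.6, 0.8, 1.0])
    recall = torch.tensor([1.0, 0.7, 0.4, 0.0])
    thresholds = torch.tensor([0.0, 0.33, 0.66, 1.0])
    # Best recall with precision >= 0.5 is 0.7, reached at threshold 0.33.
    print(recall_at_min_precision(precision, recall, thresholds, min_precision=0.5))

The 1e6 sentinel is what surfaces as the 1.0000e+06 entries in the multiclass doctest below: it makes unreachable precision targets visible in the returned thresholds instead of silently hiding them.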
From d19e52e270c70e419e3f8617fd104a1bf723786d Mon Sep 17 00:00:00 2001
From: Maxim Grechkin
Date: Thu, 8 Apr 2021 13:05:19 -0700
Subject: [PATCH 19/33] add more docs

---
 .../classification/binned_precision_recall.py | 46 +++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py
index 2eada111a99..cb0ad4a4edc 100644
--- a/torchmetrics/classification/binned_precision_recall.py
+++ b/torchmetrics/classification/binned_precision_recall.py
@@ -182,6 +182,9 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
     curve into one number. Works for both binary and multiclass problems.
     In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach.
 
+    Computation is performed in constant memory by computing precision and recall
+    for ``num_thresholds`` buckets/thresholds (evenly distributed between 0 and 1).
+
     Forward accepts
 
     - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
@@ -192,6 +195,8 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
     Args:
         num_classes: integer with number of classes. Not necessary to provide
             for binary problems.
+        num_thresholds: number of bins used for computation. More bins will lead to a more detailed
+            curve and more accurate estimates, but will be slower and consume more memory. Default: 100
         compute_on_step:
             Forward only calls ``update()`` and returns None if this is set to False. default: True
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
@@ -222,6 +227,47 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
 
 
 class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
+    """
+    Computes the highest possible recall value given the minimum precision threshold provided.
+
+    Computation is performed in constant memory by computing precision and recall
+    for ``num_thresholds`` buckets/thresholds (evenly distributed between 0 and 1).
+
+    Forward accepts
+
+    - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor
+      with probabilities, where C is the number of classes.
+
+    - ``target`` (long tensor): ``(N, ...)`` with integer labels
+
+    Args:
+        num_classes: integer with number of classes. Provide 1 for binary problems.
+        min_precision: float value specifying minimum precision threshold.
+        num_thresholds: number of bins used for computation. More bins will lead to a more detailed
+            curve and more accurate estimates, but will be slower and consume more memory. Default: 100
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+
+    Example (binary case):
+        >>> from torchmetrics import BinnedRecallAtFixedPrecision
+        >>> pred = torch.tensor([0, 0.2, 0.5, 0.8])
+        >>> target = torch.tensor([0, 1, 1, 0])
+        >>> average_precision = BinnedRecallAtFixedPrecision(num_classes=1, num_thresholds=10, min_precision=0.5)
+        >>> average_precision(pred, target)
+        (tensor(1.0000), tensor(0.1111))
+
+    Example (multiclass case):
+        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
+        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
+        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
+        ...                      
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision = BinnedRecallAtFixedPrecision(num_classes=5, num_thresholds=10, min_precision=0.5) + >>> average_precision(pred, target) + (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]), tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + """ def __init__( self, num_classes: int, From 704e7f6cf07ec64a9bc64860dd8755c690d27a70 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Thu, 8 Apr 2021 13:18:56 -0700 Subject: [PATCH 20/33] pep8 --- torchmetrics/classification/binned_precision_recall.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index cb0ad4a4edc..85884f06085 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -265,8 +265,9 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve): ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) >>> average_precision = BinnedRecallAtFixedPrecision(num_classes=5, num_thresholds=10, min_precision=0.5) - >>> average_precision(pred, target) - (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]), tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) + >>> average_precision(pred, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]), + tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) """ def __init__( self, From 00e2d842316b7306673864efa8b3c82d9796a582 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Fri, 9 Apr 2021 12:06:06 -0700 Subject: [PATCH 21/33] remove nlp test hack --- tests/functional/test_nlp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/functional/test_nlp.py b/tests/functional/test_nlp.py index 8cfe6f738cf..76bf77041da 100644 --- a/tests/functional/test_nlp.py +++ b/tests/functional/test_nlp.py @@ -63,11 +63,11 @@ def test_bleu_score(weights, n_gram, smooth_func, smooth): smoothing_function=smooth_func, ) pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-2) + assert torch.allclose(pl_output, tensor(nltk_output)) nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) - assert torch.allclose(pl_output, tensor(nltk_output), atol=1e-2) + assert torch.allclose(pl_output, tensor(nltk_output)) def test_bleu_empty(): From 88f83d6cbf531d392f811b3f21023bf442419643 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 16:49:36 +0200 Subject: [PATCH 22/33] abc --- docs/source/references/modules.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 89f60fab3fd..a1fcb5064ad 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -126,12 +126,6 @@ AveragePrecision .. autoclass:: torchmetrics.AveragePrecision :noindex: -BinnedAveragePrecision -~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: torchmetrics.BinnedAveragePrecision - :noindex: - AUC ~~~ @@ -144,6 +138,12 @@ AUROC .. autoclass:: torchmetrics.AUROC :noindex: +BinnedAveragePrecision +~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: torchmetrics.BinnedAveragePrecision + :noindex: + CohenKappa ~~~~~~~~~~ From ee0d54181a0db9d22809ec0839a230ee9ca289d8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 16:50:43 +0200 Subject: [PATCH 23/33] abc --- docs/source/references/modules.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index a1fcb5064ad..2a9a13542c3 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -144,6 +144,12 @@ BinnedAveragePrecision .. autoclass:: torchmetrics.BinnedAveragePrecision :noindex: +BinnedPrecisionRecallCurve +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.BinnedPrecisionRecallCurve + :noindex: + CohenKappa ~~~~~~~~~~ @@ -204,12 +210,6 @@ PrecisionRecallCurve .. autoclass:: torchmetrics.PrecisionRecallCurve :noindex: -BinnedPrecisionRecallCurve -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: torchmetrics.BinnedPrecisionRecallCurve - :noindex: - Recall ~~~~~~ From 7e673c1bd8410cdea9b59feae13a3ae0cc8bbb84 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Tue, 13 Apr 2021 11:58:35 -0700 Subject: [PATCH 24/33] address comments --- docs/source/references/modules.rst | 7 +++++++ tests/classification/test_binned_precision_recall.py | 5 +++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 4ed0eb56824..1414c4690d3 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -216,6 +216,13 @@ Recall .. autoclass:: torchmetrics.Recall :noindex: +BinnedRecallAtFixedPrecision +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.BinnedRecallAtFixedPrecision + :noindex: + + ROC ~~~ diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 3b02ea7713c..9bdd3770fd3 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -93,11 +93,11 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ], ) class TestBinnedRecallAtPrecision(MetricTester): + atol = 0.02 @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision): - self.atol = 0.02 # rounding will simulate binning for both implementations preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 @@ -141,11 +141,12 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o ], ) class TestBinnedAveragePrecision(MetricTester): + atol = 0.02 + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("num_thresholds", [101, 301]) def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds): - self.atol = 0.02 # rounding will simulate binning for both implementations preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6 From 322e7e34a3fd39a4f845ac2cc70a8d19ecc50a74 Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Tue, 13 Apr 2021 12:01:22 -0700 Subject: [PATCH 25/33] pep8 --- tests/classification/test_binned_precision_recall.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 9bdd3770fd3..c74094519ce 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -94,6 +94,7 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): ) class TestBinnedRecallAtPrecision(MetricTester): atol = 0.02 + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95]) From 291f04bcb3754087f1afdeeac914af453a24adce Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Tue, 13 Apr 2021 12:04:16 -0700 Subject: [PATCH 26/33] abc --- docs/source/references/modules.rst | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 86818ce8c36..41c855c7159 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -150,6 +150,12 @@ BinnedPrecisionRecallCurve .. autoclass:: torchmetrics.BinnedPrecisionRecallCurve :noindex: +BinnedRecallAtFixedPrecision +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: torchmetrics.BinnedRecallAtFixedPrecision + :noindex: + CohenKappa ~~~~~~~~~~ @@ -216,12 +222,6 @@ Recall .. autoclass:: torchmetrics.Recall :noindex: -BinnedRecallAtFixedPrecision -~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: torchmetrics.BinnedRecallAtFixedPrecision - :noindex: - ROC ~~~ From c1ae93eb1e606e5266e7629e07e3a80018be2206 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 21:11:29 +0200 Subject: [PATCH 27/33] format --- .../test_binned_precision_recall.py | 25 +++++------ tests/classification/test_f_beta.py | 8 ++-- tests/helpers/testers.py | 22 +++++----- .../classification/binned_precision_recall.py | 43 ++++++++++--------- .../functional/classification/f_beta.py | 2 +- 5 files changed, 50 insertions(+), 50 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 3b02ea7713c..ba5e56c93df 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -20,6 +20,7 @@ import torch from sklearn.metrics import average_precision_score as _sk_average_precision_score from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve +from torch import Tensor from tests.classification.inputs import _input_binary_prob from tests.classification.inputs import _input_binary_prob_plausible as _input_binary_prob_ok @@ -32,19 +33,15 @@ seed_all(42) -def recall_at_precision_x_multilabel( - predictions: torch.Tensor, targets: torch.Tensor, min_precision: float -) -> Tuple[float, float]: +def recall_at_precision_x_multilabel(predictions: Tensor, targets: Tensor, min_precision: float) -> Tuple[float, float]: precision, recall, thresholds = _sk_precision_recall_curve( - targets, predictions, + targets, + predictions, ) try: - max_recall, max_precision, best_threshold = max( - (r, p, t) - for p, r, t in zip(precision, recall, thresholds) - if p >= min_precision - ) + max_recall, max_precision, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds) + if p >= min_precision) except ValueError: max_recall, best_threshold = 0, 1e6 @@ -63,9 +60,7 @@ def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision) def 
_binary_prob_sk_metric(predictions, targets, num_classes, min_precision):
-    return recall_at_precision_x_multilabel(
-        predictions, targets, min_precision
-    )
+    return recall_at_precision_x_multilabel(predictions, targets, min_precision)
 
 
 def _multiclass_average_precision_sk_metric(predictions, targets, num_classes):
@@ -93,13 +88,14 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes):
     ],
 )
 class TestBinnedRecallAtPrecision(MetricTester):
+
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
     @pytest.mark.parametrize("min_precision", [0.05, 0.1, 0.3, 0.5, 0.8, 0.95])
     def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, min_precision):
         self.atol = 0.02
         # rounding will simulate binning for both implementations
-        preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6
+        preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6
 
         self.run_class_metric_test(
             ddp=ddp,
@@ -141,13 +137,14 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o
     ],
 )
 class TestBinnedAveragePrecision(MetricTester):
+
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
     @pytest.mark.parametrize("num_thresholds", [101, 301])
     def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
         self.atol = 0.02
         # rounding will simulate binning for both implementations
-        preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6
+        preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6
 
         self.run_class_metric_test(
             ddp=ddp,
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index fbd87408dba..b00758c32ba 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -322,11 +322,11 @@ def test_top_k(
     metric_class,
     metric_fn,
     k: int,
-    preds: torch.Tensor,
-    target: torch.Tensor,
+    preds: Tensor,
+    target: Tensor,
     average: str,
-    expected_fbeta: torch.Tensor,
-    expected_f1: torch.Tensor,
+    expected_fbeta: Tensor,
+    expected_f1: Tensor,
 ):
     """A simple test to check that top_k works as expected.
 
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 2fa72fd2df9..348516d247b 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -216,8 +216,8 @@ def _functional_test(
 def _assert_half_support(
     metric_module: Metric,
     metric_functional: Callable,
-    preds: torch.Tensor,
-    target: torch.Tensor,
+    preds: Tensor,
+    target: Tensor,
     device: str = "cpu",
     **kwargs_update
 ):
@@ -368,8 +368,8 @@ def run_class_metric_test(
         device = 'cuda' if (torch.cuda.is_available() and torch.cuda.device_count() > 0) else 'cpu'
 
         _class_test(
-            0,
-            1,
+            rank=0,
+            worldsize=1,
             preds=preds,
             target=target,
             metric_class=metric_class,
@@ -386,11 +386,11 @@ def run_class_metric_test(
 
     def run_precision_test_cpu(
         self,
-        preds: torch.Tensor,
-        target: torch.Tensor,
+        preds: Tensor,
+        target: Tensor,
         metric_module: Metric,
         metric_functional: Callable,
-        metric_args: dict = {},
+        metric_args: dict = None,
         **kwargs_update,
     ):
         """Test if a metric can be used with half precision tensors on cpu
@@ -403,17 +403,18 @@ def run_precision_test_cpu(
             kwargs_update: Additional keyword arguments that will be passed with preds and
                 target when running update on the metric.
""" + metric_args = metric_args or {} _assert_half_support( metric_module(**metric_args), metric_functional, preds, target, device="cpu", **kwargs_update ) def run_precision_test_gpu( self, - preds: torch.Tensor, - target: torch.Tensor, + preds: Tensor, + target: Tensor, metric_module: Metric, metric_functional: Callable, - metric_args: dict = {}, + metric_args: dict = None, **kwargs_update, ): """Test if an metric can be used with half precision tensors on gpu @@ -426,6 +427,7 @@ def run_precision_test_gpu( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ + metric_args = metric_args or {} _assert_half_support( metric_module(**metric_args), metric_functional, preds, target, device="cuda", **kwargs_update ) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 85884f06085..3e324a7f353 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -14,21 +14,17 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torch import Tensor from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall from torchmetrics.metric import Metric from torchmetrics.utilities.data import METRIC_EPS, to_onehot -def _recall_at_precision( - precision: torch.Tensor, recall: torch.Tensor, thresholds: torch.Tensor, min_precision: float -): +def _recall_at_precision(precision: Tensor, recall: Tensor, thresholds: Tensor, min_precision: float): try: - max_recall, max_precision, best_threshold = max( - (r, p, t) - for p, r, t in zip(precision, recall, thresholds) - if p >= min_precision - ) + max_recall, max_precision, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds) + if p >= min_precision) except ValueError: max_recall = torch.tensor(0.0, device=recall.device, dtype=recall.dtype) @@ -107,10 +103,10 @@ class BinnedPrecisionRecallCurve(Metric): tensor([0.0000, 0.5000, 1.0000])] """ - TPs: torch.Tensor - FPs: torch.Tensor - FNs: torch.Tensor - thresholds: torch.Tensor + TPs: Tensor + FPs: Tensor + FNs: Tensor + thresholds: Tensor def __init__( self, @@ -138,7 +134,7 @@ def __init__( dist_reduce_fx="sum", ) - def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: + def update(self, preds: Tensor, targets: Tensor) -> None: """ Args preds: (n_samples, n_classes) tensor @@ -160,16 +156,19 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None: self.FPs[:, i] += ((~targets) & (predictions)).sum(dim=0) self.FNs[:, i] += ((targets) & (~predictions)).sum(dim=0) - def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def compute(self) -> Tuple[Tensor, Tensor, Tensor]: """Returns float tensor of size n_classes""" precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS) recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS) # Need to guarantee that last precision=1 and recall=0, similar to precision_recall_curve - precisions = torch.cat([precisions, torch.ones(self.num_classes, 1, - dtype=precisions.dtype, device=precisions.device)], dim=1) - recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, - dtype=recalls.dtype, device=recalls.device)], dim=1) + precisions = torch.cat([ + precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device) + ], + dim=1) + recalls = torch.cat([recalls, + 
torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)],
+                            dim=1)
 
         if self.num_classes == 1:
             return (precisions[0, :], recalls[0, :], self.thresholds)
         else:
@@ -221,7 +220,7 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
         [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)]
     """
 
-    def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
+    def compute(self) -> Union[List[Tensor], Tensor]:
         precisions, recalls, _ = super().compute()
         return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)
@@ -269,6 +268,7 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
         (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]),
         tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06]))
     """
+
     def __init__(
         self,
         num_classes: int,
@@ -287,7 +287,7 @@ def __init__(
         )
         self.min_precision = min_precision
 
-    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
+    def compute(self) -> Tuple[Tensor, Tensor]:
         """Returns float tensor of size n_classes"""
         precisions, recalls, thresholds = super().compute()
 
@@ -298,5 +298,6 @@ def compute(self) -> Tuple[Tensor, Tensor]:
         thresholds_at_p = torch.zeros(self.num_classes, device=thresholds[0].device, dtype=thresholds[0].dtype)
         for i in range(self.num_classes):
             recalls_at_p[i], thresholds_at_p[i] = _recall_at_precision(
-                precisions[i], recalls[i], thresholds[i], self.min_precision)
+                precisions[i], recalls[i], thresholds[i], self.min_precision
+            )
         return (recalls_at_p, thresholds_at_p)
diff --git a/torchmetrics/functional/classification/f_beta.py b/torchmetrics/functional/classification/f_beta.py
index f0d6cfaa095..6e0ab627a37 100644
--- a/torchmetrics/functional/classification/f_beta.py
+++ b/torchmetrics/functional/classification/f_beta.py
@@ -22,7 +22,7 @@
 from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod
 
 
-def _safe_divide(num: torch.Tensor, denom: torch.Tensor):
+def _safe_divide(num: Tensor, denom: Tensor):
     """ prevent zero division """
     denom[denom == 0.] = 1
     return num / denom
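One small detail in the format pass above deserves a note: the `metric_args: dict = {}` defaults in the test helpers became `metric_args: dict = None`, with `metric_args = metric_args or {}` in the body. Python evaluates default values once, at definition time, so a mutable default is shared across calls. A minimal illustration of the pitfall being avoided (hypothetical functions, not from the diff):

    def collect_bad(item, bucket=[]):
        # The default list is created once and shared by every call.
        bucket.append(item)
        return bucket

    def collect_good(item, bucket=None):
        # Fresh list per call: the pattern the patch adopts.
        bucket = bucket or []
        bucket.append(item)
        return bucket

    assert collect_bad(1) == [1]
    assert collect_bad(2) == [1, 2]   # state leaked across calls via the shared default
    assert collect_good(1) == [1]
    assert collect_good(2) == [2]     # independent calls stay independent

The `or {}` idiom also treats an explicitly passed empty dict as "use a fresh default", which is harmless here since an empty metric_args dict behaves the same either way.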
From 72052f4e25f2a43283d9064091225da3181e4157 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 13 Apr 2021 21:16:35 +0200
Subject: [PATCH 28/33] format

---
 .../test_binned_precision_recall.py |  2 +-
 tests/helpers/testers.py | 10 +++++-----
 tests/regression/test_mean_error.py |  5 +----
 tests/regression/test_pearson.py | 16 +++++++--------
 4 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py
index 188dd6ae2d9..4c5f1bcf835 100644
--- a/tests/classification/test_binned_precision_recall.py
+++ b/tests/classification/test_binned_precision_recall.py
@@ -144,7 +144,7 @@ class TestBinnedAveragePrecision(MetricTester):
     @pytest.mark.parametrize("num_thresholds", [101, 301])
     def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step, num_thresholds):
         # rounding will simulate binning for both implementations
-        preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6
+        preds = Tensor(np.round(preds.numpy(), 2)) + 1e-6
 
         self.run_class_metric_test(
             ddp=ddp,
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 3e7d72b6ec8..19db9c0b4d9 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -434,11 +434,11 @@
     def run_differentiability_test(
         self,
-        preds: torch.Tensor,
-        target: torch.Tensor,
+        preds: Tensor,
+        target: Tensor,
         metric_module: Metric,
         metric_functional: Callable,
-        metric_args: dict = {},
+        metric_args: dict = None,
     ):
         """Test if a metric is differentiable or not
 
@@ -448,6 +448,7 @@
             metric_module: the metric module to test
             metric_args: dict with additional arguments used for class initialization
         """
+        metric_args = metric_args or {}
         # only floating point tensors can require grad
         metric = metric_module(**metric_args)
         if preds.is_floating_point():
@@ -458,8 +459,7 @@
             if metric.is_differentiable:
                 # check for numerical correctness
                 assert torch.autograd.gradcheck(
-                    partial(metric_functional, **metric_args),
-                    (preds[0].double(), target[0])
+                    partial(metric_functional, **metric_args), (preds[0].double(), target[0])
                 )
 
                 # reset as else it will carry over to other tests
diff --git a/tests/regression/test_mean_error.py b/tests/regression/test_mean_error.py
index 3e1832f99d3..7009d4fb71b 100644
--- a/tests/regression/test_mean_error.py
+++ b/tests/regression/test_mean_error.py
@@ -98,10 +98,7 @@ def test_mean_error_functional(self, preds, target, sk_metric, metric_class, met
 
     def test_mean_error_differentiability(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn):
         self.run_differentiability_test(
-            preds=preds,
-            target=target,
-            metric_module=metric_class,
-            metric_functional=metric_functional
+            preds=preds, target=target, metric_module=metric_class, metric_functional=metric_functional
         )
 
     @pytest.mark.skipif(
diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py
index d01952ee38c..91ecc56489e 100644
--- a/tests/regression/test_pearson.py
+++ b/tests/regression/test_pearson.py
@@ -24,7 +24,6 @@
 seed_all(42)
 
-
 Input = namedtuple('Input', ["preds", "target"])
 
 _single_target_inputs1 = Input(
@@ -44,10 +43,12 @@ def _sk_pearsonr(preds, target):
     return pearsonr(sk_target, sk_preds)[0]
 
 
-@pytest.mark.parametrize("preds, target", [
-    (_single_target_inputs1.preds, _single_target_inputs1.target),
-    (_single_target_inputs2.preds, _single_target_inputs2.target),
-])
+@pytest.mark.parametrize( + "preds, target", [ + (_single_target_inputs1.preds, _single_target_inputs1.target), + (_single_target_inputs2.preds, _single_target_inputs2.target), + ] +) class TestPearsonCorrcoef(MetricTester): atol = 1e-2 @@ -65,10 +66,7 @@ def test_pearson_corrcoef(self, preds, target, ddp, dist_sync_on_step): def test_pearson_corrcoef_functional(self, preds, target): self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=pearson_corrcoef, - sk_metric=_sk_pearsonr + preds=preds, target=target, metric_functional=pearson_corrcoef, sk_metric=_sk_pearsonr ) # Pearson half + cpu does not work due to missing support in torch.sqrt From 82abbbf3c7452bc4348362d367a14b86c2cb4c1b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 21:16:48 +0200 Subject: [PATCH 29/33] format --- torchmetrics/functional/regression/spearman.py | 5 +---- torchmetrics/regression/pearson.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/regression/spearman.py b/torchmetrics/functional/regression/spearman.py index a5f89c7ebc8..b2c6d5f7ddd 100644 --- a/torchmetrics/functional/regression/spearman.py +++ b/torchmetrics/functional/regression/spearman.py @@ -26,10 +26,7 @@ def _find_repeats(data: Tensor): change = torch.cat([torch.tensor([True], device=temp.device), temp[1:] != temp[:-1]]) unique = temp[change] - change_idx = torch.cat([ - torch.nonzero(change), - torch.tensor([[temp.numel()]], device=temp.device) - ]).flatten() + change_idx = torch.cat([torch.nonzero(change), torch.tensor([[temp.numel()]], device=temp.device)]).flatten() freq = change_idx[1:] - change_idx[:-1] atleast2 = freq > 1 return unique[atleast2] diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index 15efce09330..0cff5368cb3 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -55,6 +55,7 @@ class PearsonCorrcoef(Metric): tensor(0.9849) """ + def __init__( self, compute_on_step: bool = True, From 1b6acd6e578404adff8cb8e005c3ef166c53d6a9 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 21:22:12 +0200 Subject: [PATCH 30/33] format --- .../test_binned_precision_recall.py | 56 +++++-------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/tests/classification/test_binned_precision_recall.py b/tests/classification/test_binned_precision_recall.py index 4c5f1bcf835..20c1a52ec11 100644 --- a/tests/classification/test_binned_precision_recall.py +++ b/tests/classification/test_binned_precision_recall.py @@ -34,21 +34,18 @@ def recall_at_precision_x_multilabel(predictions: Tensor, targets: Tensor, min_precision: float) -> Tuple[float, float]: - precision, recall, thresholds = _sk_precision_recall_curve( - targets, - predictions, - ) + precision, recall, thresholds = _sk_precision_recall_curve(targets, predictions) try: - max_recall, max_precision, best_threshold = max((r, p, t) for p, r, t in zip(precision, recall, thresholds) - if p >= min_precision) + tuple_all = [(r, p, t) for p, r, t in zip(precision, recall, thresholds) if p >= min_precision] + max_recall, max_precision, best_threshold = max(tuple_all) except ValueError: max_recall, best_threshold = 0, 1e6 return float(max_recall), float(best_threshold) -def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision): +def _sk_prec_recall_mclass_prob(predictions, targets, num_classes, min_precision): max_recalls = torch.zeros(num_classes) best_thresholds = 
torch.zeros(num_classes) @@ -59,11 +56,11 @@ def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision) return max_recalls, best_thresholds -def _binary_prob_sk_metric(predictions, targets, num_classes, min_precision): +def _sk_prec_recall_binary_prob(predictions, targets, num_classes, min_precision): return recall_at_precision_x_multilabel(predictions, targets, min_precision) -def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): +def _sk_avg_prec_multiclass(predictions, targets, num_classes): # replace nan with 0 return np.nan_to_num(_sk_average_precision_score(targets, predictions, average=None)) @@ -71,20 +68,10 @@ def _multiclass_average_precision_sk_metric(predictions, targets, num_classes): @pytest.mark.parametrize( "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _binary_prob_sk_metric, 1), - (_input_binary_prob_ok.preds, _input_binary_prob_ok.target, _binary_prob_sk_metric, 1), - ( - _input_mlb_prob_ok.preds, - _input_mlb_prob_ok.target, - _multiclass_prob_sk_metric, - NUM_CLASSES, - ), - ( - _input_mlb_prob.preds, - _input_mlb_prob.target, - _multiclass_prob_sk_metric, - NUM_CLASSES, - ), + (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_recall_binary_prob, 1), + (_input_binary_prob_ok.preds, _input_binary_prob_ok.target, _sk_prec_recall_binary_prob, 1), + (_input_mlb_prob_ok.preds, _input_mlb_prob_ok.target, _sk_prec_recall_mclass_prob, NUM_CLASSES), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_prec_recall_mclass_prob, NUM_CLASSES), ], ) class TestBinnedRecallAtPrecision(MetricTester): @@ -115,25 +102,10 @@ def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, dist_sync_o @pytest.mark.parametrize( "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _multiclass_average_precision_sk_metric, 1), - ( - _input_binary_prob_ok.preds, - _input_binary_prob_ok.target, - _multiclass_average_precision_sk_metric, - 1, - ), - ( - _input_mlb_prob_ok.preds, - _input_mlb_prob_ok.target, - _multiclass_average_precision_sk_metric, - NUM_CLASSES, - ), - ( - _input_mlb_prob.preds, - _input_mlb_prob.target, - _multiclass_average_precision_sk_metric, - NUM_CLASSES, - ), + (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_multiclass, 1), + (_input_binary_prob_ok.preds, _input_binary_prob_ok.target, _sk_avg_prec_multiclass, 1), + (_input_mlb_prob_ok.preds, _input_mlb_prob_ok.target, _sk_avg_prec_multiclass, NUM_CLASSES), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_avg_prec_multiclass, NUM_CLASSES), ], ) class TestBinnedAveragePrecision(MetricTester): From 31c245a8e6d8f32cf6adaee40f48ecfc8d974e4c Mon Sep 17 00:00:00 2001 From: Maxim Grechkin Date: Tue, 13 Apr 2021 12:42:07 -0700 Subject: [PATCH 31/33] flake8 --- torchmetrics/classification/binned_precision_recall.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 85884f06085..af468b7d4be 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -269,6 +269,7 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve): (tensor([1.0000, 1.0000, 0.0000, 0.0000, 0.0000]), tensor([6.6667e-01, 6.6667e-01, 1.0000e+06, 1.0000e+06, 1.0000e+06])) """ + def __init__( self, num_classes: int, From 2d540323a888a462ad0ef03b52a1058fc32f4d8a Mon Sep 17 00:00:00 2001 
From: Maxim Grechkin Date: Tue, 13 Apr 2021 13:30:55 -0700 Subject: [PATCH 32/33] remove typecheck --- torchmetrics/classification/binned_precision_recall.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/torchmetrics/classification/binned_precision_recall.py b/torchmetrics/classification/binned_precision_recall.py index 3e324a7f353..eff2bd3998c 100644 --- a/torchmetrics/classification/binned_precision_recall.py +++ b/torchmetrics/classification/binned_precision_recall.py @@ -103,11 +103,6 @@ class BinnedPrecisionRecallCurve(Metric): tensor([0.0000, 0.5000, 1.0000])] """ - TPs: Tensor - FPs: Tensor - FNs: Tensor - thresholds: Tensor - def __init__( self, num_classes: int, @@ -165,7 +160,7 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: precisions = torch.cat([ precisions, torch.ones(self.num_classes, 1, dtype=precisions.dtype, device=precisions.device) ], - dim=1) + dim=1) recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1, dtype=recalls.dtype, device=recalls.device)], dim=1) From 469f2056bcf3c82ca15b72afe2898f9719ff339d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Apr 2021 23:55:24 +0200 Subject: [PATCH 33/33] chlog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fbaf20fbe2..676f13fe8db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Added `PearsonCorrcoef` ([#157](https://github.com/PyTorchLightning/metrics/pull/157)) * Added `SpearmanCorrcoef` ([#158](https://github.com/PyTorchLightning/metrics/pull/158)) * Added `Hinge` ([#120](https://github.com/PyTorchLightning/metrics/pull/120)) +- Added Binned metrics ([#128](https://github.com/PyTorchLightning/metrics/pull/128)) - Added `average='micro'` as an option in AUROC for multilabel problems ([#110](https://github.com/PyTorchLightning/metrics/pull/110)) - Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114)) - Added testing for `half` precision ([#77](https://github.com/PyTorchLightning/metrics/pull/77),