Skip to content

Commit e67148b

Browse files
authored
Merge branch 'master' into bugfix/dtype_auroc
2 parents 2e75fe6 + a6320cf commit e67148b

File tree

4 files changed

+47
-41
lines changed

4 files changed

+47
-41
lines changed

requirements/test.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ requests <=2.28.2
1414
fire <=0.5.0
1515

1616
cloudpickle >1.3, <=2.2.1
17-
scikit-learn >1.0, <1.1.1
17+
scikit-learn >1.0, <1.2.2

src/torchmetrics/classification/precision_recall_curve.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,9 @@ class BinaryPrecisionRecallCurve(Metric):
101101
>>> target = torch.tensor([0, 1, 1, 0])
102102
>>> bprc = BinaryPrecisionRecallCurve(thresholds=None)
103103
>>> bprc(preds, target) # doctest: +NORMALIZE_WHITESPACE
104-
(tensor([0.6667, 0.5000, 0.0000, 1.0000]),
105-
tensor([1.0000, 0.5000, 0.0000, 0.0000]),
106-
tensor([0.5000, 0.7000, 0.8000]))
104+
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
105+
tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
106+
tensor([0.0000, 0.5000, 0.7000, 0.8000]))
107107
>>> bprc = BinaryPrecisionRecallCurve(thresholds=5)
108108
>>> bprc(preds, target) # doctest: +NORMALIZE_WHITESPACE
109109
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
@@ -215,12 +215,13 @@ class MulticlassPrecisionRecallCurve(Metric):
215215
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=None)
216216
>>> precision, recall, thresholds = mcprc(preds, target)
217217
>>> precision # doctest: +NORMALIZE_WHITESPACE
218-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
218+
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
219219
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
220220
>>> recall
221-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
221+
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
222222
>>> thresholds
223-
[tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
223+
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
224+
tensor(0.0500)]
224225
>>> mcprc = MulticlassPrecisionRecallCurve(num_classes=5, thresholds=5)
225226
>>> mcprc(preds, target) # doctest: +NORMALIZE_WHITESPACE
226227
(tensor([[0.2500, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000],
@@ -359,14 +360,13 @@ class MultilabelPrecisionRecallCurve(Metric):
359360
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=None)
360361
>>> precision, recall, thresholds = mlprc(preds, target)
361362
>>> precision # doctest: +NORMALIZE_WHITESPACE
362-
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]),
363+
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
363364
tensor([0.7500, 1.0000, 1.0000, 1.0000])]
364365
>>> recall # doctest: +NORMALIZE_WHITESPACE
365-
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]),
366+
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
366367
tensor([1.0000, 0.6667, 0.3333, 0.0000])]
367368
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
368-
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]),
369-
tensor([0.0500, 0.3500, 0.7500])]
369+
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
370370
>>> mlprc = MultilabelPrecisionRecallCurve(num_labels=3, thresholds=5)
371371
>>> mlprc(preds, target) # doctest: +NORMALIZE_WHITESPACE
372372
(tensor([[0.5000, 0.5000, 1.0000, 1.0000, 0.0000, 1.0000],
@@ -447,11 +447,11 @@ class PrecisionRecallCurve:
447447
>>> pr_curve = PrecisionRecallCurve(task="binary")
448448
>>> precision, recall, thresholds = pr_curve(pred, target)
449449
>>> precision
450-
tensor([0.6667, 0.5000, 1.0000, 1.0000])
450+
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
451451
>>> recall
452-
tensor([1.0000, 0.5000, 0.5000, 0.0000])
452+
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
453453
>>> thresholds
454-
tensor([0.1000, 0.4000, 0.8000])
454+
tensor([0.0000, 0.1000, 0.4000, 0.8000])
455455
456456
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
457457
... [0.05, 0.75, 0.05, 0.05, 0.05],
@@ -461,12 +461,13 @@ class PrecisionRecallCurve:
461461
>>> pr_curve = PrecisionRecallCurve(task="multiclass", num_classes=5)
462462
>>> precision, recall, thresholds = pr_curve(pred, target)
463463
>>> precision
464-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
464+
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
465465
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
466466
>>> recall
467-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
467+
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
468468
>>> thresholds
469-
[tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)]
469+
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
470+
tensor(0.0500)]
470471
"""
471472

472473
def __new__(

src/torchmetrics/functional/classification/precision_recall_curve.py

+21-24
Original file line numberDiff line numberDiff line change
@@ -267,15 +267,11 @@ def _binary_precision_recall_curve_compute(
267267
precision = tps / (tps + fps)
268268
recall = tps / tps[-1]
269269

270-
# stop when full recall attained and reverse the outputs so recall is decreasing
271-
last_ind = torch.where(tps == tps[-1])[0][0]
272-
sl = slice(0, last_ind.item() + 1)
273-
274270
# reverse explicitly with ``flip`` rather than a negative-step slice, since that
275271
# would introduce negative strides that are not yet supported in pytorch
276-
precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)])
277-
recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)])
278-
thresholds = reversed(thresholds[sl]).detach().clone()
272+
precision = torch.cat([precision.flip(0), torch.ones(1, dtype=precision.dtype, device=precision.device)])
273+
recall = torch.cat([recall.flip(0), torch.zeros(1, dtype=recall.dtype, device=recall.device)])
274+
thresholds = thresholds.flip(0).detach().clone()
279275
return precision, recall, thresholds
280276

281277

@@ -338,9 +334,9 @@ def binary_precision_recall_curve(
338334
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
339335
>>> target = torch.tensor([0, 1, 1, 0])
340336
>>> binary_precision_recall_curve(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE
341-
(tensor([0.6667, 0.5000, 0.0000, 1.0000]),
342-
tensor([1.0000, 0.5000, 0.0000, 0.0000]),
343-
tensor([0.5000, 0.7000, 0.8000]))
337+
(tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
338+
tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
339+
tensor([0.0000, 0.5000, 0.7000, 0.8000]))
344340
>>> binary_precision_recall_curve(preds, target, thresholds=5) # doctest: +NORMALIZE_WHITESPACE
345341
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
346342
tensor([1., 1., 1., 0., 0., 0.]),
@@ -607,12 +603,13 @@ def multiclass_precision_recall_curve(
607603
... preds, target, num_classes=5, thresholds=None
608604
... )
609605
>>> precision # doctest: +NORMALIZE_WHITESPACE
610-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
606+
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
611607
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
612608
>>> recall
613-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
609+
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
614610
>>> thresholds
615-
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
611+
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
612+
tensor([0.0500])]
616613
>>> multiclass_precision_recall_curve(
617614
... preds, target, num_classes=5, thresholds=5
618615
... ) # doctest: +NORMALIZE_WHITESPACE
@@ -837,14 +834,13 @@ def multilabel_precision_recall_curve(
837834
... preds, target, num_labels=3, thresholds=None
838835
... )
839836
>>> precision # doctest: +NORMALIZE_WHITESPACE
840-
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]),
837+
[tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0.6667, 0.5000, 0.0000, 1.0000]),
841838
tensor([0.7500, 1.0000, 1.0000, 1.0000])]
842839
>>> recall # doctest: +NORMALIZE_WHITESPACE
843-
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]),
840+
[tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
844841
tensor([1.0000, 0.6667, 0.3333, 0.0000])]
845842
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
846-
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]),
847-
tensor([0.0500, 0.3500, 0.7500])]
843+
[tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
848844
>>> multilabel_precision_recall_curve(
849845
... preds, target, num_labels=3, thresholds=5
850846
... ) # doctest: +NORMALIZE_WHITESPACE
@@ -887,15 +883,15 @@ def precision_recall_curve(
887883
:func:`multilabel_precision_recall_curve` for the specific details of each argument influence and examples.
888884
889885
Legacy Example:
890-
>>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0])
886+
>>> pred = torch.tensor([0, 0.1, 0.8, 0.4])
891887
>>> target = torch.tensor([0, 1, 1, 0])
892888
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary')
893889
>>> precision
894-
tensor([0.6667, 0.5000, 0.0000, 1.0000])
890+
tensor([0.5000, 0.6667, 0.5000, 1.0000, 1.0000])
895891
>>> recall
896-
tensor([1.0000, 0.5000, 0.0000, 0.0000])
892+
tensor([1.0000, 1.0000, 0.5000, 0.5000, 0.0000])
897893
>>> thresholds
898-
tensor([0.7311, 0.8808, 0.9526])
894+
tensor([0.0000, 0.1000, 0.4000, 0.8000])
899895
900896
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
901897
... [0.05, 0.75, 0.05, 0.05, 0.05],
@@ -904,12 +900,13 @@ def precision_recall_curve(
904900
>>> target = torch.tensor([0, 1, 3, 2])
905901
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5)
906902
>>> precision
907-
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
903+
[tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 1.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]),
908904
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
909905
>>> recall
910-
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
906+
[tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
911907
>>> thresholds
912-
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
908+
[tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
909+
tensor([0.0500])]
913910
"""
914911
task = ClassificationTask.from_str(task)
915912
if task == ClassificationTask.BINARY:

tests/unittests/classification/test_precision_recall_curve.py

+8
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import operator
1415
from functools import partial
1516

1617
import numpy as np
1718
import pytest
1819
import torch
20+
from lightning_utilities import compare_version
1921
from scipy.special import expit as sigmoid
2022
from scipy.special import softmax
2123
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve
@@ -53,6 +55,7 @@ class TestBinaryPrecisionRecallCurve(MetricTester):
5355

5456
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
5557
@pytest.mark.parametrize("ddp", [True, False])
58+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
5659
def test_binary_precision_recall_curve(self, input, ddp, ignore_index):
5760
"""Test class implementation of metric."""
5861
preds, target = input
@@ -71,6 +74,7 @@ def test_binary_precision_recall_curve(self, input, ddp, ignore_index):
7174
)
7275

7376
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
77+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
7478
def test_binary_precision_recall_curve_functional(self, input, ignore_index):
7579
"""Test functional implementation of metric."""
7680
preds, target = input
@@ -178,6 +182,7 @@ class TestMulticlassPrecisionRecallCurve(MetricTester):
178182

179183
@pytest.mark.parametrize("ignore_index", [None, -1])
180184
@pytest.mark.parametrize("ddp", [True, False])
185+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
181186
def test_multiclass_precision_recall_curve(self, input, ddp, ignore_index):
182187
"""Test class implementation of metric."""
183188
preds, target = input
@@ -197,6 +202,7 @@ def test_multiclass_precision_recall_curve(self, input, ddp, ignore_index):
197202
)
198203

199204
@pytest.mark.parametrize("ignore_index", [None, -1])
205+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
200206
def test_multiclass_precision_recall_curve_functional(self, input, ignore_index):
201207
"""Test functional implementation of metric."""
202208
preds, target = input
@@ -298,6 +304,7 @@ class TestMultilabelPrecisionRecallCurve(MetricTester):
298304

299305
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
300306
@pytest.mark.parametrize("ddp", [True, False])
307+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
301308
def test_multilabel_precision_recall_curve(self, input, ddp, ignore_index):
302309
"""Test class implementation of metric."""
303310
preds, target = input
@@ -317,6 +324,7 @@ def test_multilabel_precision_recall_curve(self, input, ddp, ignore_index):
317324
)
318325

319326
@pytest.mark.parametrize("ignore_index", [None, -1, 0])
327+
@pytest.mark.skipif(compare_version("sklearn", operator.lt, "1.1.0"), reason="Restricted to latest `sklearn`")
320328
def test_multilabel_precision_recall_curve_functional(self, input, ignore_index):
321329
"""Test functional implementation of metric."""
322330
preds, target = input

0 commit comments

Comments
 (0)