@@ -267,15 +267,11 @@ def _binary_precision_recall_curve_compute(
267
267
precision = tps / (tps + fps )
268
268
recall = tps / tps [- 1 ]
269
269
270
- # stop when full recall attained and reverse the outputs so recall is decreasing
271
- last_ind = torch .where (tps == tps [- 1 ])[0 ][0 ]
272
- sl = slice (0 , last_ind .item () + 1 )
273
-
274
270
# need to call reversed explicitly, since including that to slice would
275
271
# introduce negative strides that are not yet supported in pytorch
276
- precision = torch .cat ([reversed ( precision [ sl ] ), torch .ones (1 , dtype = precision .dtype , device = precision .device )])
277
- recall = torch .cat ([reversed ( recall [ sl ] ), torch .zeros (1 , dtype = recall .dtype , device = recall .device )])
278
- thresholds = reversed ( thresholds [ sl ] ).detach ().clone ()
272
+ precision = torch .cat ([precision . flip ( 0 ), torch .ones (1 , dtype = precision .dtype , device = precision .device )])
273
+ recall = torch .cat ([recall . flip ( 0 ), torch .zeros (1 , dtype = recall .dtype , device = recall .device )])
274
+ thresholds = thresholds . flip ( 0 ).detach ().clone ()
279
275
return precision , recall , thresholds
280
276
281
277
@@ -338,9 +334,9 @@ def binary_precision_recall_curve(
338
334
>>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
339
335
>>> target = torch.tensor([0, 1, 1, 0])
340
336
>>> binary_precision_recall_curve(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE
341
- (tensor([0.6667, 0.5000, 0.0000, 1.0000]),
342
- tensor([1.0000, 0.5000, 0.0000, 0.0000]),
343
- tensor([0.5000, 0.7000, 0.8000]))
337
+ (tensor([0.5000, 0. 6667, 0.5000, 0.0000, 1.0000]),
338
+ tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
339
+ tensor([0.0000, 0. 5000, 0.7000, 0.8000]))
344
340
>>> binary_precision_recall_curve(preds, target, thresholds=5) # doctest: +NORMALIZE_WHITESPACE
345
341
(tensor([0.5000, 0.6667, 0.6667, 0.0000, 0.0000, 1.0000]),
346
342
tensor([1., 1., 1., 0., 0., 0.]),
@@ -607,12 +603,13 @@ def multiclass_precision_recall_curve(
607
603
... preds, target, num_classes=5, thresholds=None
608
604
... )
609
605
>>> precision # doctest: +NORMALIZE_WHITESPACE
610
- [tensor([1. , 1.]), tensor([1. , 1.]), tensor([0.2500, 0.0000, 1.0000]),
606
+ [tensor([0.2500, 1.0000 , 1.0000 ]), tensor([0.2500, 1.0000 , 1.0000 ]), tensor([0.2500, 0.0000, 1.0000]),
611
607
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
612
608
>>> recall
613
- [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
609
+ [tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
614
610
>>> thresholds
615
- [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
611
+ [tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
612
+ tensor([0.0500])]
616
613
>>> multiclass_precision_recall_curve(
617
614
... preds, target, num_classes=5, thresholds=5
618
615
... ) # doctest: +NORMALIZE_WHITESPACE
@@ -837,14 +834,13 @@ def multilabel_precision_recall_curve(
837
834
... preds, target, num_labels=3, thresholds=None
838
835
... )
839
836
>>> precision # doctest: +NORMALIZE_WHITESPACE
840
- [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.6667, 0.5000, 0.0000, 1.0000]),
837
+ [tensor([0.5000, 0.5000, 1.0000, 1.0000]), tensor([0.5000, 0. 6667, 0.5000, 0.0000, 1.0000]),
841
838
tensor([0.7500, 1.0000, 1.0000, 1.0000])]
842
839
>>> recall # doctest: +NORMALIZE_WHITESPACE
843
- [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 0.5000, 0.0000, 0.0000]),
840
+ [tensor([1.0000, 0.5000, 0.5000, 0.0000]), tensor([1.0000, 1.0000, 0.5000, 0.0000, 0.0000]),
844
841
tensor([1.0000, 0.6667, 0.3333, 0.0000])]
845
842
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
846
- [tensor([0.0500, 0.4500, 0.7500]), tensor([0.5500, 0.6500, 0.7500]),
847
- tensor([0.0500, 0.3500, 0.7500])]
843
+ [tensor([0.0500, 0.4500, 0.7500]), tensor([0.0500, 0.5500, 0.6500, 0.7500]), tensor([0.0500, 0.3500, 0.7500])]
848
844
>>> multilabel_precision_recall_curve(
849
845
... preds, target, num_labels=3, thresholds=5
850
846
... ) # doctest: +NORMALIZE_WHITESPACE
@@ -887,15 +883,15 @@ def precision_recall_curve(
887
883
:func:`multilabel_precision_recall_curve` for the specific details of each argument influence and examples.
888
884
889
885
Legacy Example:
890
- >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0 ])
886
+ >>> pred = torch.tensor([0, 0.1, 0.8, 0.4 ])
891
887
>>> target = torch.tensor([0, 1, 1, 0])
892
888
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary')
893
889
>>> precision
894
- tensor([0.6667, 0.5000, 0 .0000, 1.0000])
890
+ tensor([0.5000, 0. 6667, 0.5000, 1 .0000, 1.0000])
895
891
>>> recall
896
- tensor([1.0000, 0.5000, 0.0000 , 0.0000])
892
+ tensor([1.0000, 1.0000, 0.5000, 0.5000 , 0.0000])
897
893
>>> thresholds
898
- tensor([0.7311 , 0.8808 , 0.9526 ])
894
+ tensor([0.0000 , 0.1000 , 0.4000, 0.8000 ])
899
895
900
896
>>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
901
897
... [0.05, 0.75, 0.05, 0.05, 0.05],
@@ -904,12 +900,13 @@ def precision_recall_curve(
904
900
>>> target = torch.tensor([0, 1, 3, 2])
905
901
>>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5)
906
902
>>> precision
907
- [tensor([1. , 1.]), tensor([1. , 1.]), tensor([0.2500, 0.0000, 1.0000]),
903
+ [tensor([0.2500, 1.0000 , 1.0000 ]), tensor([0.2500, 1.0000 , 1.0000 ]), tensor([0.2500, 0.0000, 1.0000]),
908
904
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
909
905
>>> recall
910
- [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
906
+ [tensor([1., 1., 0.]), tensor([1., 1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
911
907
>>> thresholds
912
- [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
908
+ [tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]),
909
+ tensor([0.0500])]
913
910
"""
914
911
task = ClassificationTask .from_str (task )
915
912
if task == ClassificationTask .BINARY :
0 commit comments