@@ -197,19 +197,24 @@ def plot(
197
197
.. plot::
198
198
:scale: 75
199
199
200
- >>> from torch import randn, randint
201
- >>> import torch.nn.functional as F
202
- >>> from torchmetrics.classification import BinaryROC
203
- >>> preds = F.softmax(randn(20, 2), dim=1)
200
+ >>> from torch import rand, randint
201
+ >>> from torchmetrics.classification import BinaryPrecisionRecallCurve
202
+ >>> preds = rand(20)
204
203
>>> target = randint(2, (20,))
205
- >>> metric = BinaryROC ()
206
- >>> metric.update(preds[:, 1] , target)
207
- >>> fig_, ax_ = metric.plot()
204
+ >>> metric = BinaryPrecisionRecallCurve ()
205
+ >>> metric.update(preds, target)
206
+ >>> fig_, ax_ = metric.plot(score=True )
208
207
209
208
"""
210
- curve = curve or self .compute ()
211
- score = _auc_compute_without_check (curve [0 ], curve [1 ], 1.0 ) if not curve and score is True else None
212
- return plot_curve (curve , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__ )
209
+ curve_computed = curve or self .compute ()
210
+ score = (
211
+ _auc_compute_without_check (curve_computed [0 ], curve_computed [1 ], 1.0 )
212
+ if not curve and score is True
213
+ else None
214
+ )
215
+ return plot_curve (
216
+ curve_computed , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__
217
+ )
213
218
214
219
215
220
class MulticlassPrecisionRecallCurve (Metric ):
@@ -374,18 +379,21 @@ def plot(
374
379
:scale: 75
375
380
376
381
>>> from torch import randn, randint
377
- >>> import torch.nn.functional as F
378
- >>> from torchmetrics.classification import BinaryROC
379
- >>> preds = F.softmax(randn(20, 2), dim=1)
380
- >>> target = randint(2, (20,))
381
- >>> metric = BinaryROC()
382
- >>> metric.update(preds[:, 1], target)
383
- >>> fig_, ax_ = metric.plot()
382
+ >>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
383
+ >>> preds = randn(20, 3).softmax(dim=-1)
384
+ >>> target = randint(3, (20,))
385
+ >>> metric = MulticlassPrecisionRecallCurve(num_classes=3)
386
+ >>> metric.update(preds, target)
387
+ >>> fig_, ax_ = metric.plot(score=True)
384
388
385
389
"""
386
- curve = curve or self .compute ()
387
- score = _reduce_auroc (curve [0 ], curve [1 ], average = None ) if not curve and score is True else None
388
- return plot_curve (curve , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__ )
390
+ curve_computed = curve or self .compute ()
391
+ score = (
392
+ _reduce_auroc (curve_computed [0 ], curve_computed [1 ], average = None ) if not curve and score is True else None
393
+ )
394
+ return plot_curve (
395
+ curve_computed , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__
396
+ )
389
397
390
398
391
399
class MultilabelPrecisionRecallCurve (Metric ):
@@ -558,19 +566,22 @@ def plot(
558
566
.. plot::
559
567
:scale: 75
560
568
561
- >>> from torch import randn, randint
562
- >>> import torch.nn.functional as F
563
- >>> from torchmetrics.classification import BinaryROC
564
- >>> preds = F.softmax(randn(20, 2), dim=1)
565
- >>> target = randint(2, (20,))
566
- >>> metric = BinaryROC()
567
- >>> metric.update(preds[:, 1], target)
568
- >>> fig_, ax_ = metric.plot()
569
+ >>> from torch import rand, randint
570
+ >>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
571
+ >>> preds = rand(20, 3)
572
+ >>> target = randint(2, (20,3))
573
+ >>> metric = MultilabelPrecisionRecallCurve(num_labels=3)
574
+ >>> metric.update(preds, target)
575
+ >>> fig_, ax_ = metric.plot(score=True)
569
576
570
577
"""
571
- curve = curve or self .compute ()
572
- score = _reduce_auroc (curve [0 ], curve [1 ], average = None ) if not curve and score is True else None
573
- return plot_curve (curve , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__ )
578
+ curve_computed = curve or self .compute ()
579
+ score = (
580
+ _reduce_auroc (curve_computed [0 ], curve_computed [1 ], average = None ) if not curve and score is True else None
581
+ )
582
+ return plot_curve (
583
+ curve_computed , score = score , ax = ax , label_names = ("Precision" , "Recall" ), name = self .__class__ .__name__
584
+ )
574
585
575
586
576
587
class PrecisionRecallCurve :
0 commit comments