Skip to content

Commit 7cda67c

Browse files
authored
Fix missing AUC score when plotting (#1948)
1 parent fb74428 commit 7cda67c

File tree

3 files changed

+79
-60
lines changed

3 files changed

+79
-60
lines changed

CHANGELOG.md

+3
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4141
- Fixed bug related to the `prefix/postfix` arguments in `MetricCollection` and `ClasswiseWrapper` being duplicated ([#1918](https://github.com/Lightning-AI/torchmetrics/pull/1918))
4242

4343

44+
- Fixed missing AUC score when plotting classification metrics that support the `score` argument ([#1948](https://github.com/Lightning-AI/torchmetrics/pull/1948))
45+
46+
4447
## [1.0.1] - 2023-07-13
4548

4649
### Fixed

src/torchmetrics/classification/precision_recall_curve.py

+42-31
Original file line number | Diff line number | Diff line change
@@ -197,19 +197,24 @@ def plot(
197197
.. plot::
198198
:scale: 75
199199
200-
>>> from torch import randn, randint
201-
>>> import torch.nn.functional as F
202-
>>> from torchmetrics.classification import BinaryROC
203-
>>> preds = F.softmax(randn(20, 2), dim=1)
200+
>>> from torch import rand, randint
201+
>>> from torchmetrics.classification import BinaryPrecisionRecallCurve
202+
>>> preds = rand(20)
204203
>>> target = randint(2, (20,))
205-
>>> metric = BinaryROC()
206-
>>> metric.update(preds[:, 1], target)
207-
>>> fig_, ax_ = metric.plot()
204+
>>> metric = BinaryPrecisionRecallCurve()
205+
>>> metric.update(preds, target)
206+
>>> fig_, ax_ = metric.plot(score=True)
208207
209208
"""
210-
curve = curve or self.compute()
211-
score = _auc_compute_without_check(curve[0], curve[1], 1.0) if not curve and score is True else None
212-
return plot_curve(curve, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__)
209+
curve_computed = curve or self.compute()
210+
score = (
211+
_auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0)
212+
if not curve and score is True
213+
else None
214+
)
215+
return plot_curve(
216+
curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__
217+
)
213218

214219

215220
class MulticlassPrecisionRecallCurve(Metric):
@@ -374,18 +379,21 @@ def plot(
374379
:scale: 75
375380
376381
>>> from torch import randn, randint
377-
>>> import torch.nn.functional as F
378-
>>> from torchmetrics.classification import BinaryROC
379-
>>> preds = F.softmax(randn(20, 2), dim=1)
380-
>>> target = randint(2, (20,))
381-
>>> metric = BinaryROC()
382-
>>> metric.update(preds[:, 1], target)
383-
>>> fig_, ax_ = metric.plot()
382+
>>> from torchmetrics.classification import MulticlassPrecisionRecallCurve
383+
>>> preds = randn(20, 3).softmax(dim=-1)
384+
>>> target = randint(3, (20,))
385+
>>> metric = MulticlassPrecisionRecallCurve(num_classes=3)
386+
>>> metric.update(preds, target)
387+
>>> fig_, ax_ = metric.plot(score=True)
384388
385389
"""
386-
curve = curve or self.compute()
387-
score = _reduce_auroc(curve[0], curve[1], average=None) if not curve and score is True else None
388-
return plot_curve(curve, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__)
390+
curve_computed = curve or self.compute()
391+
score = (
392+
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
393+
)
394+
return plot_curve(
395+
curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__
396+
)
389397

390398

391399
class MultilabelPrecisionRecallCurve(Metric):
@@ -558,19 +566,22 @@ def plot(
558566
.. plot::
559567
:scale: 75
560568
561-
>>> from torch import randn, randint
562-
>>> import torch.nn.functional as F
563-
>>> from torchmetrics.classification import BinaryROC
564-
>>> preds = F.softmax(randn(20, 2), dim=1)
565-
>>> target = randint(2, (20,))
566-
>>> metric = BinaryROC()
567-
>>> metric.update(preds[:, 1], target)
568-
>>> fig_, ax_ = metric.plot()
569+
>>> from torch import rand, randint
570+
>>> from torchmetrics.classification import MultilabelPrecisionRecallCurve
571+
>>> preds = rand(20, 3)
572+
>>> target = randint(2, (20,3))
573+
>>> metric = MultilabelPrecisionRecallCurve(num_labels=3)
574+
>>> metric.update(preds, target)
575+
>>> fig_, ax_ = metric.plot(score=True)
569576
570577
"""
571-
curve = curve or self.compute()
572-
score = _reduce_auroc(curve[0], curve[1], average=None) if not curve and score is True else None
573-
return plot_curve(curve, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__)
578+
curve_computed = curve or self.compute()
579+
score = (
580+
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
581+
)
582+
return plot_curve(
583+
curve_computed, score=score, ax=ax, label_names=("Precision", "Recall"), name=self.__class__.__name__
584+
)
574585

575586

576587
class PrecisionRecallCurve:

src/torchmetrics/classification/roc.py

+34-29
Original file line number | Diff line number | Diff line change
@@ -143,20 +143,23 @@ def plot(
143143
.. plot::
144144
:scale: 75
145145
146-
>>> from torch import randn, randint
147-
>>> import torch.nn.functional as F
146+
>>> from torch import rand, randint
148147
>>> from torchmetrics.classification import BinaryROC
149-
>>> preds = F.softmax(randn(20, 2), dim=1)
148+
>>> preds = rand(20)
150149
>>> target = randint(2, (20,))
151150
>>> metric = BinaryROC()
152-
>>> metric.update(preds[:, 1], target)
153-
>>> fig_, ax_ = metric.plot()
151+
>>> metric.update(preds, target)
152+
>>> fig_, ax_ = metric.plot(score=True)
154153
155154
"""
156-
curve = curve or self.compute()
157-
score = _auc_compute_without_check(curve[0], curve[1], 1.0) if not curve and score is True else None
155+
curve_computed = curve or self.compute()
156+
score = (
157+
_auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0)
158+
if not curve and score is True
159+
else None
160+
)
158161
return plot_curve(
159-
curve,
162+
curve_computed,
160163
score=score,
161164
ax=ax,
162165
label_names=("False positive rate", "True positive rate"),
@@ -296,19 +299,20 @@ def plot(
296299
:scale: 75
297300
298301
>>> from torch import randn, randint
299-
>>> import torch.nn.functional as F
300-
>>> from torchmetrics.classification import BinaryROC
301-
>>> preds = F.softmax(randn(20, 2), dim=1)
302-
>>> target = randint(2, (20,))
303-
>>> metric = BinaryROC()
304-
>>> metric.update(preds[:, 1], target)
305-
>>> fig_, ax_ = metric.plot()
302+
>>> from torchmetrics.classification import MulticlassROC
303+
>>> preds = randn(20, 3).softmax(dim=-1)
304+
>>> target = randint(3, (20,))
305+
>>> metric = MulticlassROC(num_classes=3)
306+
>>> metric.update(preds, target)
307+
>>> fig_, ax_ = metric.plot(score=True)
306308
307309
"""
308-
curve = curve or self.compute()
309-
score = _reduce_auroc(curve[0], curve[1], average=None) if not curve and score is True else None
310+
curve_computed = curve or self.compute()
311+
score = (
312+
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
313+
)
310314
return plot_curve(
311-
curve,
315+
curve_computed,
312316
score=score,
313317
ax=ax,
314318
label_names=("False positive rate", "True positive rate"),
@@ -449,20 +453,21 @@ def plot(
449453
.. plot::
450454
:scale: 75
451455
452-
>>> from torch import randn, randint
453-
>>> import torch.nn.functional as F
454-
>>> from torchmetrics.classification import BinaryROC
455-
>>> preds = F.softmax(randn(20, 2), dim=1)
456-
>>> target = randint(2, (20,))
457-
>>> metric = BinaryROC()
458-
>>> metric.update(preds[:, 1], target)
459-
>>> fig_, ax_ = metric.plot()
456+
>>> from torch import rand, randint
457+
>>> from torchmetrics.classification import MultilabelROC
458+
>>> preds = rand(20, 3)
459+
>>> target = randint(2, (20,3))
460+
>>> metric = MultilabelROC(num_labels=3)
461+
>>> metric.update(preds, target)
462+
>>> fig_, ax_ = metric.plot(score=True)
460463
461464
"""
462-
curve = curve or self.compute()
463-
score = _reduce_auroc(curve[0], curve[1], average=None) if not curve and score is True else None
465+
curve_computed = curve or self.compute()
466+
score = (
467+
_reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
468+
)
464469
return plot_curve(
465-
curve,
470+
curve_computed,
466471
score=score,
467472
ax=ax,
468473
label_names=("False positive rate", "True positive rate"),

0 commit comments

Comments (0)