From ef3e47305ef9309e1dab0a64f5379d5fd7135a0f Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Fri, 15 Mar 2024 15:03:05 +0100
Subject: [PATCH] Fix how auc scores are calculated in `PrecisionRecallCurve.plot` methods (#2437)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
(cherry picked from commit 0a6ad011cb46df90cff57aef6510bc475cfeaf25)
---
 CHANGELOG.md                                  |  2 ++
 .../classification/precision_recall_curve.py  | 19 +++++++++++++------
 src/torchmetrics/classification/roc.py        |  9 ++++++---
 .../functional/classification/auroc.py        |  5 +++--
 src/torchmetrics/utilities/compute.py         |  4 ++--
 5 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 12fcc66e178..834231eb8cc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
 
+- Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))
+
 
 ## [1.3.1] - 2024-02-12
 
 ### Fixed

diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py
index 46b874a74b2..366f11710d4 100644
--- a/src/torchmetrics/classification/precision_recall_curve.py
+++ b/src/torchmetrics/classification/precision_recall_curve.py
@@ -188,7 +188,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:
@@ -215,7 +216,7 @@ def plot(
         curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])
 
         score = (
-            _auc_compute_without_check(curve_computed[0], curve_computed[1], 1.0)
+            _auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
             if not curve and score is True
             else None
         )
@@ -390,7 +391,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:
@@ -416,7 +418,9 @@ def plot(
         # switch order as the standard way is recall along x-axis and precision along y-axis
         curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])
         score = (
-            _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
+            _reduce_auroc(curve_computed[0], curve_computed[1], average=None, direction=-1.0)
+            if not curve and score is True
+            else None
         )
         return plot_curve(
             curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
@@ -583,7 +587,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:
@@ -609,7 +614,9 @@ def plot(
         # switch order as the standard way is recall along x-axis and precision along y-axis
         curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])
         score = (
-            _reduce_auroc(curve_computed[0], curve_computed[1], average=None) if not curve and score is True else None
+            _reduce_auroc(curve_computed[0], curve_computed[1], average=None, direction=-1.0)
+            if not curve and score is True
+            else None
         )
         return plot_curve(
             curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__

diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py
index 7f1479a1ae6..40bd8c36327 100644
--- a/src/torchmetrics/classification/roc.py
+++ b/src/torchmetrics/classification/roc.py
@@ -134,7 +134,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:
@@ -303,7 +304,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:
@@ -461,7 +463,8 @@ def plot(
             curve: the output of either `metric.compute` or `metric.forward`. If no value is provided, will
                 automatically call `metric.compute` and plot that result.
             score: Provide a area-under-the-curve score to be displayed on the plot. If `True` and no curve is provided,
-                will automatically compute the score.
+                will automatically compute the score. The score is computed by using the trapezoidal rule to compute the
+                area under the curve.
             ax: An matplotlib axis object. If provided will add plot to that axis
 
         Returns:

diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py
index acd94f4050e..fb802c05ec3 100644
--- a/src/torchmetrics/functional/classification/auroc.py
+++ b/src/torchmetrics/functional/classification/auroc.py
@@ -47,12 +47,13 @@ def _reduce_auroc(
     tpr: Union[Tensor, List[Tensor]],
     average: Optional[Literal["macro", "weighted", "none"]] = "macro",
     weights: Optional[Tensor] = None,
+    direction: float = 1.0,
 ) -> Tensor:
     """Reduce multiple average precision score into one number."""
     if isinstance(fpr, Tensor) and isinstance(tpr, Tensor):
-        res = _auc_compute_without_check(fpr, tpr, 1.0, axis=1)
+        res = _auc_compute_without_check(fpr, tpr, direction=direction, axis=1)
     else:
-        res = torch.stack([_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)])
+        res = torch.stack([_auc_compute_without_check(x, y, direction=direction) for x, y in zip(fpr, tpr)])
     if average is None or average == "none":
         return res
     if torch.isnan(res).any():

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 9ff82ce987f..12613103ca6 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -92,8 +92,8 @@ def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int
     """
     with torch.no_grad():
-        auc_: Tensor = torch.trapz(y, x, dim=axis) * direction
-        return auc_
+        auc_score: Tensor = torch.trapz(y, x, dim=axis) * direction
+        return auc_score
 
 
 def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor:
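
Why the sign flip: `torch.trapz` integrates y over x in the order the points are given. On a precision-recall plot the x-axis is recall, which arrives sorted in descending order (recall falls as the decision threshold rises), so the raw trapezoidal integral comes out negative; ROC curves are unaffected because FPR is already ascending. The sketch below illustrates the behavior this patch corrects; `auc_with_direction` and the toy curve values are illustrative stand-ins, not the library code.

```python
import torch
from torch import Tensor


def auc_with_direction(x: Tensor, y: Tensor, direction: float = 1.0) -> Tensor:
    # Trapezoidal area under y(x); `direction` compensates for the sort
    # order of x, mirroring how `_auc_compute_without_check` is called.
    with torch.no_grad():
        return torch.trapz(y, x) * direction


# Toy precision-recall curve: recall (the x-axis) is in *descending* order.
recall = torch.tensor([1.0, 0.8, 0.5, 0.0])
precision = torch.tensor([0.5, 0.6, 0.8, 1.0])

# Old call sites hard-coded 1.0: integrating along a decreasing x-axis
# yields a negative "AUC" score on the plot.
print(auc_with_direction(recall, precision, direction=1.0))   # tensor(-0.7700)

# The fix passes direction=-1.0 for PR curves, restoring the positive area.
print(auc_with_direction(recall, precision, direction=-1.0))  # tensor(0.7700)
```

Sorting recall into ascending order before integrating would give the same magnitude; the patch instead threads a `direction` argument through `_reduce_auroc` so the binary, multiclass, and multilabel plot paths all reuse the same reduction.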