From 38a2119359f22dd9525cb7978eb2ac230a36ab59 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 18 Mar 2021 18:21:59 +0100 Subject: [PATCH] Prune metrics: precision & recall 6/n (#6573) * avg precision * precision * recall * curve * tests * chlog * isort * fix --- CHANGELOG.md | 2 + .../classification/average_precision.py | 103 +--- .../classification/precision_recall.py | 275 +--------- .../classification/precision_recall_curve.py | 131 +---- .../metrics/functional/average_precision.py | 78 +-- .../metrics/functional/precision_recall.py | 477 +----------------- .../functional/precision_recall_curve.py | 204 +------- tests/deprecated_api/test_remove_1-4.py | 10 - .../classification/test_average_precision.py | 97 ---- .../classification/test_precision_recall.py | 348 ------------- .../test_precision_recall_curve.py | 97 ---- .../metrics/functional/test_classification.py | 26 - tests/metrics/test_remove_1-5_metrics.py | 72 ++- 13 files changed, 127 insertions(+), 1793 deletions(-) delete mode 100644 tests/metrics/classification/test_average_precision.py delete mode 100644 tests/metrics/classification/test_precision_recall.py delete mode 100644 tests/metrics/classification/test_precision_recall_curve.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d397079bb072..2b0629a4194d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -79,6 +79,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#6572](https://github.com/PyTorchLightning/pytorch-lightning/pull/6572), + [#6573](https://github.com/PyTorchLightning/pytorch-lightning/pull/6573), + ) diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index adcdd86ed1ca8..106d6ea6111b2 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -11,64 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Union +from typing import Any, Optional -import torch -from torchmetrics import Metric +from torchmetrics import AveragePrecision as _AveragePrecision -from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class AveragePrecision(Metric): - """ - Computes the average precision score, which summarises the precision recall - curve into one number. Works for both binary and multiclass problems. - In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. 
default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(pos_label=1) - >>> average_precision(pred, target) - tensor(1.) - - Example (multiclass case): - - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(num_classes=5) - >>> average_precision(pred, target) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - - """ +class AveragePrecision(_AveragePrecision): + @deprecated(target=_AveragePrecision, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -77,48 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `AveragePrecision` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _average_precision_update( - preds, target, self.num_classes, self.pos_label - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - Compute the average precision score - - Returns: - tensor with average precision. If multiclass will return list - of such tensors, one for each class + This implementation refers to :class:`~torchmetrics.AveragePrecision`. + .. deprecated:: + Use :class:`~torchmetrics.AveragePrecision`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _average_precision_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 5a163097ee0bc..ae3ee40da0ca5 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -13,112 +13,15 @@ # limitations under the License. from typing import Any, Callable, Optional -import torch +from torchmetrics import Precision as _Precision +from torchmetrics import Recall as _Recall -from pytorch_lightning.metrics.classification.stat_scores import StatScores -from pytorch_lightning.metrics.functional.precision_recall import _precision_compute, _recall_compute +from pytorch_lightning.utilities.deprecation import deprecated -class Precision(StatScores): - r""" - Computes `Precision `_: - - .. 
math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - - Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - multilabel: - .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. 
- dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. - - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - - Example: - - >>> from pytorch_lightning.metrics import Precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision = Precision(average='macro', num_classes=3) - >>> precision(preds, target) - tensor(0.1667) - >>> precision = Precision(average='micro') - >>> precision(preds, target) - tensor(0.2500) - - """ +class Precision(_Precision): + @deprecated(target=_Precision, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -134,142 +37,17 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.average = average - - def compute(self) -> torch.Tensor: """ - Computes the precision score based on inputs passed in to ``update`` previously. + This implementation refers to :class:`~torchmetrics.Precision`. - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + .. deprecated:: + Use :class:`~torchmetrics.Precision`. Will be removed in v1.5.0. """ - tp, fp, tn, fn = self._get_final_stats() - return _precision_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) - -class Recall(StatScores): - r""" - Computes `Recall `_: - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} - - Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. 
- - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - multilabel: - .. warning :: This parameter is deprecated and has no effect. Will be removed in v1.4.0. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - compute_on_step: - Forward only calls ``update()`` and return ``None`` if this is set to ``False``. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step - process_group: - Specify the process group on which synchronization is called. - default: ``None`` (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When ``None``, DDP - will be used to perform the allgather. - - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. 
- - Example: - - >>> from pytorch_lightning.metrics import Recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall = Recall(average='macro', num_classes=3) - >>> recall(preds, target) - tensor(0.3333) - >>> recall = Recall(average='micro') - >>> recall(preds, target) - tensor(0.2500) - - """ +class Recall(_Recall): + @deprecated(target=_Recall, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -285,36 +63,9 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - dist_sync_fn=dist_sync_fn, - ) - - self.average = average - - def compute(self) -> torch.Tensor: """ - Computes the recall score based on inputs passed in to ``update`` previously. - - Return: - The shape of the returned tensor depends on the ``average`` parameter + This implementation refers to :class:`~torchmetrics.Recall`. - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + .. deprecated:: + Use :class:`~torchmetrics.Recall`. Will be removed in v1.5.0. """ - tp, fp, tn, fn = self._get_final_stats() - return _recall_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index ccf821d829d78..fb8f6a812028c 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -11,80 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional -import torch -from torchmetrics import Metric +from torchmetrics import PrecisionRecallCurve as _PrecisionRecallCurve -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _precision_recall_curve_compute, - _precision_recall_curve_update, -) -from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.deprecation import deprecated -class PrecisionRecallCurve(Metric): - """ - Computes precision-recall pairs for different thresholds. Works for both - binary and multiclass problems. In the case of multiclass, the values will - be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. 
- pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - - Example (binary case): - - >>> from pytorch_lightning.metrics import PrecisionRecallCurve - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> pr_curve = PrecisionRecallCurve(pos_label=1) - >>> precision, recall, thresholds = pr_curve(pred, target) - >>> precision - tensor([0.6667, 0.5000, 0.0000, 1.0000]) - >>> recall - tensor([1.0000, 0.5000, 0.0000, 0.0000]) - >>> thresholds - tensor([1, 2, 3]) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics import PrecisionRecallCurve - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> pr_curve = PrecisionRecallCurve(num_classes=5) - >>> precision, recall, thresholds = pr_curve(pred, target) - >>> precision # doctest: +NORMALIZE_WHITESPACE - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), - tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] - >>> recall - [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] - >>> thresholds - [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] - - """ +class PrecisionRecallCurve(_PrecisionRecallCurve): + @deprecated(target=_PrecisionRecallCurve, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__( self, num_classes: Optional[int] = None, @@ -93,60 +29,9 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): - super().__init__( - compute_on_step=compute_on_step, - dist_sync_on_step=dist_sync_on_step, - process_group=process_group, - ) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx=None) - self.add_state("target", default=[], dist_reduce_fx=None) - - rank_zero_warn( - 'Metric `PrecisionRecallCurve` will save all targets and predictions in buffer.' - ' For large datasets this may lead to large memory footprint.' - ) - - def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _precision_recall_curve_update( - preds, target, self.num_classes, self.pos_label - ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute( - self - ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - """ - Compute the precision-recall curve - - Returns: - 3-element tuple containing + This implementation refers to :class:`~torchmetrics.PrecisionRecallCurve`. 
- precision: - tensor where element i is the precision of predictions with - score >= thresholds[i] and the last element is 1. - If multiclass, this is a list of such tensors, one for each class. - recall: - tensor where element i is the recall of predictions with - score >= thresholds[i] and the last element is 0. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - Thresholds used for computing precision/recall scores + .. deprecated:: + Use :class:`~torchmetrics.PrecisionRecallCurve`. Will be removed in v1.5.0. """ - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/pytorch_lightning/metrics/functional/average_precision.py b/pytorch_lightning/metrics/functional/average_precision.py index 2a82c4f38f20e..e4ce3941fe008 100644 --- a/pytorch_lightning/metrics/functional/average_precision.py +++ b/pytorch_lightning/metrics/functional/average_precision.py @@ -11,45 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Sequence, Union import torch +from torchmetrics.functional import average_precision as _average_precision -from pytorch_lightning.metrics.functional.precision_recall_curve import ( - _precision_recall_curve_compute, - _precision_recall_curve_update, -) - - -def _average_precision_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - return _precision_recall_curve_update(preds, target, num_classes, pos_label) - - -def _average_precision_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None -) -> Union[List[torch.Tensor], torch.Tensor]: - precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - if num_classes == 1: - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - res = [] - for p, r in zip(precision, recall): - res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) - return res +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_average_precision, ver_deprecate="1.3.0", ver_remove="1.5.0") def average_precision( preds: torch.Tensor, target: torch.Tensor, @@ -58,42 +28,6 @@ def average_precision( sample_weights: Optional[Sequence] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """ - Computes the average precision score. - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - sample_weights: sample weights for each data point - - Returns: - tensor with average precision. 
If multiclass will return list - of such tensors, one for each class - - Example (binary case): - - >>> from pytorch_lightning.metrics.functional import average_precision - >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision(pred, target, pos_label=1) - tensor(1.) - - Example (multiclass case): - - >>> from pytorch_lightning.metrics.functional import average_precision - >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision(pred, target, num_classes=5) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - + .. deprecated:: + Use :func:`torchmetrics.functional.average_precision`. Will be removed in v1.5.0. """ - preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label) - return _average_precision_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index b6d26237cf287..1b5be382a13af 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,29 +14,14 @@ from typing import Optional import torch -from torchmetrics.classification.stat_scores import _reduce_stat_scores +from torchmetrics.functional import precision as _precision +from torchmetrics.functional import precision_recall as _precision_recall +from torchmetrics.functional import recall as _recall -from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update -from pytorch_lightning.utilities import rank_zero_warn - - -def _precision_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], -) -> torch.Tensor: - return _reduce_stat_scores( - numerator=tp, - denominator=tp + fp, - weights=None if average != "weighted" else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_precision, ver_deprecate="1.3.0", ver_remove="1.5.0") def precision( preds: torch.Tensor, target: torch.Tensor, @@ -47,166 +32,14 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - class_reduction: Optional[str] = None, ) -> torch.Tensor: - r""" - Computes `Precision `_: - - .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - - Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). 
- - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, - ``"samples"``, ``"none"`` or ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. 
- - Example: - - >>> from pytorch_lightning.metrics.functional import precision - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision(preds, target, average='macro', num_classes=3) - tensor(0.1667) - >>> precision(preds, target, average='micro') - tensor(0.2500) - """ - if class_reduction: - rank_zero_warn( - "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", - DeprecationWarning, - ) - average = class_reduction - - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - ) - - return _precision_compute(tp, fp, tn, fn, average, mdmc_average) - - -def _recall_compute( - tp: torch.Tensor, - fp: torch.Tensor, - tn: torch.Tensor, - fn: torch.Tensor, - average: str, - mdmc_average: Optional[str], -) -> torch.Tensor: - return _reduce_stat_scores( - numerator=tp, - denominator=tp + fn, - weights=None if average != "weighted" else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) + .. deprecated:: + Use :func:`torchmetrics.functional.precision`. Will be removed in v1.5.0. + """ +@deprecated(target=_recall, ver_deprecate="1.3.0", ver_remove="1.5.0") def recall( preds: torch.Tensor, target: torch.Tensor, @@ -217,149 +50,14 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, is_multiclass: Optional[bool] = None, - class_reduction: Optional[str] = None, ) -> torch.Tensor: - r""" - Computes `Recall `_: - - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} - - Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics accross classes (with equal weights for each class). 
- - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics accross classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, - ``"samples"``, ``"none"`` or ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. 
-
-    Example:
-
-        >>> from pytorch_lightning.metrics.functional import recall
-        >>> preds = torch.tensor([2, 0, 2, 1])
-        >>> target = torch.tensor([1, 1, 2, 0])
-        >>> recall(preds, target, average='macro', num_classes=3)
-        tensor(0.3333)
-        >>> recall(preds, target, average='micro')
-        tensor(0.2500)
-
     """
-    if class_reduction:
-        rank_zero_warn(
-            "This `class_reduction` parameter was deprecated in v1.2.0 in favor of"
-            " `reduce`. It will be removed in v1.4.0",
-            DeprecationWarning,
-        )
-        average = class_reduction
-
-    allowed_average = ["micro", "macro", "weighted", "samples", "none", None]
-    if average not in allowed_average:
-        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
-
-    allowed_mdmc_average = [None, "samplewise", "global"]
-    if mdmc_average not in allowed_mdmc_average:
-        raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.")
-
-    if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1):
-        raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.")
-
-    if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1):
-        raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")
-
-    reduce = "macro" if average in ["weighted", "none", None] else average
-    tp, fp, tn, fn = _stat_scores_update(
-        preds,
-        target,
-        reduce=reduce,
-        mdmc_reduce=mdmc_average,
-        threshold=threshold,
-        num_classes=num_classes,
-        top_k=top_k,
-        is_multiclass=is_multiclass,
-        ignore_index=ignore_index,
-    )
-
-    return _recall_compute(tp, fp, tn, fn, average, mdmc_average)
+    .. deprecated::
+        Use :func:`torchmetrics.functional.recall`. Will be removed in v1.5.0.
+    """


+@deprecated(target=_precision_recall, ver_deprecate="1.3.0", ver_remove="1.5.0")
 def precision_recall(
     preds: torch.Tensor,
     target: torch.Tensor,
@@ -370,151 +68,8 @@ def precision_recall(
     preds: torch.Tensor,
     target: torch.Tensor,
     average: str = "micro",
     mdmc_average: Optional[str] = None,
     ignore_index: Optional[int] = None,
     num_classes: Optional[int] = None,
     threshold: float = 0.5,
     top_k: Optional[int] = None,
     is_multiclass: Optional[bool] = None,
-    class_reduction: Optional[str] = None,
 ) -> torch.Tensor:
-    r"""
-    Computes `Precision and Recall `_:
-
-    .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}}
-
-
-    .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}}
-
-    Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number
-    of true positives, false negatives and false positives respecitively. With the use of
-    ``top_k`` parameter, this metric can generalize to Recall@K and Precision@K.
-
-    The reduction method (how the recall scores are aggregated) is controlled by the
-    ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the
-    multi-dimensional multi-class case.
-
-    Args:
-        preds: Predictions from model (probabilities, or labels)
-        target: Ground truth values
-        average:
-            Defines the reduction that is applied. Should be one of the following:
-
-            - ``'micro'`` [default]: Calculate the metric globally, accross all samples and classes.
-            - ``'macro'``: Calculate the metric for each class separately, and average the
-              metrics accross classes (with equal weights for each class).
-            - ``'weighted'``: Calculate the metric for each class separately, and average the
-              metrics accross classes, weighting each class by its support (``tp + fn``).
-            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
-              the metric for every class.
- - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - Note that what is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - class_reduction: - .. warning :: This parameter is deprecated, use ``average``. Will be removed in v1.4.0. - - Return: - The function returns a tuple with two elements: precision and recall. Their shape - depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, they are a single element tensor - - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for - the number of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, - ``"samples"``, ``"none"`` or ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. 
- - Example: - - >>> from pytorch_lightning.metrics.functional import precision_recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision_recall(preds, target, average='macro', num_classes=3) - (tensor(0.1667), tensor(0.3333)) - >>> precision_recall(preds, target, average='micro') - (tensor(0.2500), tensor(0.2500)) - """ - if class_reduction: - rank_zero_warn( - "This `class_reduction` parameter was deprecated in v1.2.0 in favor of" - " `reduce`. It will be removed in v1.4.0", - DeprecationWarning, - ) - average = class_reduction - - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError("The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - ) - - precision = _precision_compute(tp, fp, tn, fn, average, mdmc_average) - recall = _recall_compute(tp, fp, tn, fn, average, mdmc_average) - - return precision, recall + .. deprecated:: + Use :func:`torchmetrics.functional.precision_recall`. Will be removed in v1.5.0. + """ diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index 0d5760ce789cd..d1d643ba70c22 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -14,140 +14,12 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -import torch.nn.functional as F +from torchmetrics.functional import precision_recall_curve as _precision_recall_curve -from pytorch_lightning.utilities import rank_zero_warn - - -def _binary_clf_curve( - preds: torch.Tensor, - target: torch.Tensor, - sample_weights: Optional[Sequence] = None, - pos_label: int = 1., -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py - """ - if sample_weights is not None and not isinstance(sample_weights, torch.Tensor): - sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float) - - # remove class dimension if necessary - if preds.ndim > target.ndim: - preds = preds[:, 0] - desc_score_indices = torch.argsort(preds, descending=True) - - preds = preds[desc_score_indices] - target = target[desc_score_indices] - - if sample_weights is not None: - weight = sample_weights[desc_score_indices] - else: - weight = 1. - - # pred typically has many tied values. Here we extract - # the indices associated with the distinct values. 
We also - # concatenate a value for the end of the curve. - distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0] - threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) - target = (target == pos_label).to(torch.long) - tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] - - if sample_weights is not None: - # express fps as a cumsum to ensure fps is increasing even in - # the presence of floating point errors - fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] - else: - fps = 1 + threshold_idxs - tps - - return fps, tps, preds[threshold_idxs] - - -def _precision_recall_curve_update( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[torch.Tensor, torch.Tensor, int, int]: - if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - # single class evaluation - if len(preds.shape) == len(target.shape): - num_classes = 1 - if pos_label is None: - rank_zero_warn('`pos_label` automatically set 1.') - pos_label = 1 - preds = preds.flatten() - target = target.flatten() - - # multi class evaluation - if len(preds.shape) == len(target.shape) + 1: - if pos_label is not None: - rank_zero_warn( - 'Argument `pos_label` should be `None` when running' - f' multiclass precision recall curve. Got {pos_label}' - ) - if num_classes != preds.shape[1]: - raise ValueError( - f'Argument `num_classes` was set to {num_classes} in' - f' metric `precision_recall_curve` but detected {preds.shape[1]}' - ' number of classes from predictions' - ) - preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) - target = target.flatten() - - return preds, target, num_classes, pos_label - - -def _precision_recall_curve_compute( - preds: torch.Tensor, - target: torch.Tensor, - num_classes: int, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], - List[torch.Tensor]]]: - - if num_classes == 1: - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained - # and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) - - thresholds = reversed(thresholds[sl]).clone() - - return precision, recall, thresholds - - # Recursively call per class - precision, recall, thresholds = [], [], [] - for c in range(num_classes): - preds_c = preds[:, c] - res = precision_recall_curve( - preds=preds_c, - target=target, - num_classes=1, - pos_label=c, - sample_weights=sample_weights, - ) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - - return precision, recall, thresholds +from pytorch_lightning.utilities.deprecation import deprecated +@deprecated(target=_precision_recall_curve, ver_deprecate="1.3.0", 
ver_remove="1.5.0")
 def precision_recall_curve(
     preds: torch.Tensor,
     target: torch.Tensor,
@@ -155,72 +27,8 @@ def precision_recall_curve(
     preds: torch.Tensor,
     target: torch.Tensor,
     num_classes: Optional[int] = None,
     pos_label: Optional[int] = None,
     sample_weights: Optional[Sequence] = None,
 ) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor],
-                                                                  List[torch.Tensor]]]:
+                                                                  List[torch.Tensor]], ]:
     """
-    Computes precision-recall pairs for different thresholds.
-
-    Args:
-        preds: predictions from model (probabilities)
-        target: ground truth labels
-        num_classes: integer with number of classes. Not nessesary to provide
-            for binary problems.
-        pos_label: integer determining the positive class. Default is ``None``
-            which for binary problem is translate to 1. For multiclass problems
-            this argument should not be set as we iteratively change it in the
-            range [0,num_classes-1]
-        sample_weights: sample weights for each data point
-
-    Returns:
-        3-element tuple containing
-
-        precision:
-            tensor where element i is the precision of predictions with
-            score >= thresholds[i] and the last element is 1.
-            If multiclass, this is a list of such tensors, one for each class.
-        recall:
-            tensor where element i is the recall of predictions with
-            score >= thresholds[i] and the last element is 0.
-            If multiclass, this is a list of such tensors, one for each class.
-        thresholds:
-            Thresholds used for computing precision/recall scores
-
-    Example (binary case):
-
-        >>> from pytorch_lightning.metrics.functional import precision_recall_curve
-        >>> pred = torch.tensor([0, 1, 2, 3])
-        >>> target = torch.tensor([0, 1, 1, 0])
-        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
-        >>> precision
-        tensor([0.6667, 0.5000, 0.0000, 1.0000])
-        >>> recall
-        tensor([1.0000, 0.5000, 0.0000, 0.0000])
-        >>> thresholds
-        tensor([1, 2, 3])
-
-    Raises:
-        ValueError:
-            If ``preds`` and ``target`` don't have the same number of dimensions,
-            or one additional dimension for ``preds``.
-        ValueError:
-            If the number of classes deduced from ``preds`` is not the same as the
-            ``num_classes`` provided.
-
-    Example (multiclass case):
-
-        >>> from pytorch_lightning.metrics.functional import precision_recall_curve
-        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
-        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
-        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
-        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
-        >>> target = torch.tensor([0, 1, 3, 2])
-        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
-        >>> precision   # doctest: +NORMALIZE_WHITESPACE
-        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
-         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
-        >>> recall
-        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
-        >>> thresholds
-        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
+    .. deprecated::
+        Use :func:`torchmetrics.functional.precision_recall_curve`. Will be removed in v1.5.0.
""" - preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) - return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 39f5e0dca5075..74bec18e8ddb3 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -130,16 +130,6 @@ def test_v1_4_0_deprecated_metrics(): with pytest.deprecated_call(match='will be removed in v1.4'): precision_recall(torch.randint(0, 2, (10, 3, 3)), torch.randint(0, 2, (10, 3, 3))) - # Testing deprecation of class_reduction arg in the *new* precision - from pytorch_lightning.metrics.functional import precision - with pytest.deprecated_call(match='will be removed in v1.4'): - precision(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') - - # Testing deprecation of class_reduction arg in the *new* recall - from pytorch_lightning.metrics.functional import recall - with pytest.deprecated_call(match='will be removed in v1.4'): - recall(torch.randint(0, 2, (10, )), torch.randint(0, 2, (10, )), class_reduction='micro') - from pytorch_lightning.metrics.functional.classification import auc with pytest.deprecated_call(match='will be removed in v1.4'): auc(torch.rand(10, ).sort().values, torch.rand(10, )) diff --git a/tests/metrics/classification/test_average_precision.py b/tests/metrics/classification/test_average_precision.py deleted file mode 100644 index 7cab20883e970..0000000000000 --- a/tests/metrics/classification/test_average_precision.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision_score - -from pytorch_lightning.metrics.classification.average_precision import AveragePrecision -from pytorch_lightning.metrics.functional.average_precision import average_precision -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_average_precision_score(y_true, probas_pred, num_classes=1): - if num_classes == 1: - return sk_average_precision_score(y_true, probas_pred) - - res = [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) - return res - - -def _sk_avg_prec_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - 
-@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestAveragePrecision(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=AveragePrecision, - sk_metric=partial(sk_metric, num_classes=num_classes), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_average_precision_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=average_precision, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['scores', 'target', 'expected_score'], - [ - # Check the average_precision_score of a constant predictor is - # the TPR - # Generate a dataset with 25% of positives - # And a constant score - # The precision is then the fraction of positive whatever the recall - # is, as there is only one threshold: - pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), - # With threshold 0.8 : 1 TP and 2 TN and one FN - pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), - ] -) -def test_average_precision(scores, target, expected_score): - assert average_precision(scores, target) == expected_score diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py deleted file mode 100644 index c9e5467414832..0000000000000 --- a/tests/metrics/classification/test_precision_recall.py +++ /dev/null @@ -1,348 +0,0 @@ -from functools import partial -from typing import Callable, Optional - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_score, recall_score -from torchmetrics import Metric -from torchmetrics.classification.checks import _input_format_classification - -from pytorch_lightning.metrics import Precision, Recall -from pytorch_lightning.metrics.functional import precision, precision_recall, recall -from tests.metrics.classification.inputs import _input_binary, _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass as _input_mcls -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.classification.inputs import _input_multilabel as _input_mlb -from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD - -torch.manual_seed(42) - - -def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None): - if average == "none": - average = None - if num_classes == 1: - average = "binary" - - labels = list(range(num_classes)) - try: - labels.remove(ignore_index) - except ValueError: - 
pass - - sk_preds, sk_target, _ = _input_format_classification( - preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - - sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - - if len(labels) != num_classes and not average: - sk_scores = np.insert(sk_scores, ignore_index, np.nan) - - return sk_scores - - -def _sk_prec_recall_multidim_multiclass( - preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average -): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass - ) - - if mdmc_average == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) - elif mdmc_average == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) - - scores.append(np.expand_dims(scores_i, 0)) - - return np.concatenate(scores).mean(axis=0) - - -@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - with pytest.raises(ValueError, match=match_str): - precision_recall( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, - ) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_zero_division(metric_class, metric_fn): - """ Test that zero_division works correctly (currently should just set to 0). """ - - preds = torch.tensor([1, 2, 1, 1]) - target = torch.tensor([2, 1, 2, 1]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - - assert result_cl[0] == result_fn[0] == 0 - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. - - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. 
However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). - """ - - preds = torch.tensor([1, 1, 0, 0]) - target = torch.tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) - cl_metric(preds, target) - - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - - assert result_cl == result_fn == 0 - - -@pytest.mark.parametrize( - "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] -) -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize( - "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper", - [ - (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), - (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), - (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), - (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", - _sk_prec_recall_multidim_multiclass - ), - (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), - ( - _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", - _sk_prec_recall_multidim_multiclass - ), - ], -) -class TestPrecisionRecall(MetricTester): - - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_class( - self, - ddp: bool, - dist_sync_on_step: bool, - preds: torch.Tensor, - target: torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=metric_class, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - check_dist_sync_on_step=True, - check_batch=True, - ) - - def test_precision_recall_fn( - self, - preds: torch.Tensor, - target: 
torch.Tensor, - sk_wrapper: Callable, - metric_class: Metric, - metric_fn: Callable, - sk_fn: Callable, - is_multiclass: Optional[bool], - num_classes: Optional[int], - average: str, - mdmc_average: Optional[str], - ignore_index: Optional[int], - ): - if num_classes == 1 and average != "micro": - pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if average == "weighted" and ignore_index is not None and mdmc_average is not None: - pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - - self.run_functional_metric_test( - preds, - target, - metric_functional=metric_fn, - sk_metric=partial( - sk_wrapper, - sk_fn=sk_fn, - average=average, - num_classes=num_classes, - is_multiclass=is_multiclass, - ignore_index=ignore_index, - mdmc_average=mdmc_average, - ), - metric_args={ - "num_classes": num_classes, - "average": average, - "threshold": THRESHOLD, - "is_multiclass": is_multiclass, - "ignore_index": ignore_index, - "mdmc_average": mdmc_average, - }, - ) - - -@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -def test_precision_recall_joint(average): - """A simple test of the joint precision_recall metric. - - No need to test this thorougly, as it is just a combination of precision and recall, - which are already tested thoroughly. - """ - - precision_result = precision( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - recall_result = recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - prec_recall_result = precision_recall( - _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES - ) - - assert torch.equal(precision_result, prec_recall_result[0]) - assert torch.equal(recall_result, prec_recall_result[1]) - - -_mc_k_target = torch.tensor([0, 1, 2]) -_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -@pytest.mark.parametrize( - "k, preds, target, average, expected_prec, expected_recall", - [ - (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), - (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)), - (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), - (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)), - ], -) -def test_top_k( - metric_class, - metric_fn, - k: int, - preds: torch.Tensor, - target: torch.Tensor, - average: str, - expected_prec: torch.Tensor, - expected_recall: torch.Tensor, -): - """A simple test to check that top_k works as expected. - - Just a sanity check, the tests in StatScores should already guarantee - the corectness of results. 
- """ - - class_metric = metric_class(top_k=k, average=average, num_classes=3) - class_metric.update(preds, target) - - if metric_class.__name__ == "Precision": - result = expected_prec - else: - result = expected_recall - - assert torch.equal(class_metric.compute(), result) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) diff --git a/tests/metrics/classification/test_precision_recall_curve.py b/tests/metrics/classification/test_precision_recall_curve.py deleted file mode 100644 index 6a60e1fd36fdd..0000000000000 --- a/tests/metrics/classification/test_precision_recall_curve.py +++ /dev/null @@ -1,97 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve - -from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve -from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve -from tests.metrics.classification.inputs import _input_binary_prob -from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.metrics.utils import MetricTester, NUM_CLASSES - -torch.manual_seed(42) - - -def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): - """ Adjusted comparison function that can also handles multiclass """ - if num_classes == 1: - return sk_precision_recall_curve(y_true, probas_pred) - - precision, recall, thresholds = [], [], [] - for i in range(num_classes): - y_true_temp = np.zeros_like(y_true) - y_true_temp[y_true == i] = 1 - res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - return precision, recall, thresholds - - -def _sk_prec_rc_binary_prob(preds, target, num_classes=1): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.reshape(-1, num_classes).numpy() - sk_target = target.view(-1).numpy() - - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): - sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() - sk_target = target.view(-1).numpy() - return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), - ] -) -class TestPrecisionRecallCurve(MetricTester): - - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=PrecisionRecallCurve, - sk_metric=partial(sk_metric, num_classes=num_classes), - 
dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes} - ) - - def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=precision_recall_curve, - sk_metric=partial(sk_metric, num_classes=num_classes), - metric_args={"num_classes": num_classes}, - ) - - -@pytest.mark.parametrize( - ['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], - [pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])] -) -def test_pr_curve(pred, target, expected_p, expected_r, expected_t): - p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) - assert p.size() == r.size() - assert p.size(0) == t.size(0) + 1 - - assert torch.allclose(p, torch.tensor(expected_p).to(p)) - assert torch.allclose(r, torch.tensor(expected_r).to(r)) - assert torch.allclose(t, torch.tensor(expected_t).to(t)) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 44109d40b2efa..f9f3033427b1b 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,33 +1,7 @@ import pytest import torch -from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score -from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve - - -@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ - pytest.param(1, 1., 42), - pytest.param(None, 1., 42), -]) -def test_binary_clf_curve(sample_weight, pos_label, exp_shape): - # TODO: move back the pred and target to test func arguments - # if you fix the array inside the function, you'd also have fix the shape, - # because when the array changes, you also have to fix the shape - seed_everything(0) - pred = torch.randint(low=51, high=99, size=(100, ), dtype=torch.float) / 100 - target = torch.tensor([0, 1] * 50, dtype=torch.int) - if sample_weight is not None: - sample_weight = torch.ones_like(pred) * sample_weight - - fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - - assert isinstance(tps, torch.Tensor) - assert isinstance(fps, torch.Tensor) - assert isinstance(thresh, torch.Tensor) - assert tps.shape == (exp_shape, ) - assert fps.shape == (exp_shape, ) - assert thresh.shape == (exp_shape, ) @pytest.mark.parametrize(['pred', 'target', 'expected'], [ diff --git a/tests/metrics/test_remove_1-5_metrics.py b/tests/metrics/test_remove_1-5_metrics.py index f284a9d85bc47..41ccfb6da8015 100644 --- a/tests/metrics/test_remove_1-5_metrics.py +++ b/tests/metrics/test_remove_1-5_metrics.py @@ -16,8 +16,27 @@ import pytest import torch -from pytorch_lightning.metrics import Accuracy, AUC, AUROC, MetricCollection, ROC -from pytorch_lightning.metrics.functional import auc, auroc, roc +from pytorch_lightning.metrics import ( + Accuracy, + AUC, + AUROC, + AveragePrecision, + MetricCollection, + Precision, + PrecisionRecallCurve, + Recall, + ROC, +) +from pytorch_lightning.metrics.functional import ( + auc, + auroc, + average_precision, + precision, + precision_recall, + precision_recall_curve, + recall, + roc, +) from pytorch_lightning.metrics.functional.accuracy import accuracy from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -88,13 +107,58 @@ def test_v1_5_metric_auc_auroc(): target = 
torch.tensor([0, 1, 1, 1]) roc.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): - fpr, tpr, thresholds = roc(preds, target, pos_label=1) + fpr, tpr, thrs = roc(preds, target, pos_label=1) assert torch.equal(fpr, torch.tensor([0., 0., 0., 0., 1.])) assert torch.allclose(tpr, torch.tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]), atol=1e-4) - assert torch.equal(thresholds, torch.tensor([4, 3, 2, 1, 0])) + assert torch.equal(thrs, torch.tensor([4, 3, 2, 1, 0])) preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) target = torch.tensor([0, 0, 1, 1, 1]) auroc.warned = False with pytest.deprecated_call(match='It will be removed in v1.5.0'): assert auroc(preds, target) == torch.tensor(0.5) + + +def test_v1_5_metric_precision_recall(): + AveragePrecision.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + AveragePrecision() + + Precision.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Precision() + + Recall.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + Recall() + + PrecisionRecallCurve.__init__.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + PrecisionRecallCurve() + + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 1, 1]) + average_precision.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert average_precision(pred, target) == torch.tensor(1.) + + precision.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert precision(pred, target) == torch.tensor(0.5) + + recall.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + assert recall(pred, target) == torch.tensor(0.5) + + precision_recall.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc = precision_recall(pred, target) + assert prec == torch.tensor(0.5) + assert rc == torch.tensor(0.5) + + precision_recall_curve.warned = False + with pytest.deprecated_call(match='It will be removed in v1.5.0'): + prec, rc, thrs = precision_recall_curve(pred, target) + assert torch.equal(prec, torch.tensor([1., 1., 1., 1.])) + assert torch.allclose(rc, torch.tensor([1., 0.6667, 0.3333, 0.]), atol=1e-4) + assert torch.equal(thrs, torch.tensor([1, 2, 3]))
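Since every wrapper touched by this patch simply forwards to its torchmetrics counterpart, migrated user code only needs to swap the import root. A minimal sketch, assuming a torchmetrics release that ships these metrics is installed (the tensors are made-up example data, and torchmetrics' default binary threshold of 0.5 is assumed):

import torch
from torchmetrics import Precision, Recall
from torchmetrics.functional import precision_recall_curve

preds = torch.tensor([0.1, 0.8, 0.4, 0.9])
target = torch.tensor([0, 1, 0, 1])

# class-based API: update per batch, compute once at the end
precision = Precision()
recall = Recall()
precision.update(preds, target)
recall.update(preds, target)
print(precision.compute(), recall.compute())

# functional API: single call, returns the full precision-recall curve
prec, rc, thresholds = precision_recall_curve(preds, target, pos_label=1)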