diff --git a/CHANGELOG.md b/CHANGELOG.md index 466ef6f929e83..57105e252dfb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,11 +17,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `max_fpr` parameter to `auroc` metric for computing partial auroc metric ([#3790](https://github.com/PyTorchLightning/pytorch-lightning/pull/3790)) +- `StatScores` metric to compute the number of true positives, false positives, true negatives and false negatives ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) + + ### Changed +- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) ### Deprecated +- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) ### Removed diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 5396403402072..47a9947b6a7c7 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -251,13 +251,62 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]]) ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]]) -In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class -but are actually binary/multi-label. For example, if both predictions and targets are 1d -binary tensors. Or it could be the other way around, you want to treat binary/multi-label -inputs as 2-class (multi-dimensional) multi-class inputs. + +Using the ``is_multiclass`` parameter +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In some cases, you might have inputs which appear to be (multi-dimensional) multi-class +but are actually binary/multi-label - for example, if both predictions and targets are +integer (binary) tensors. 
Or it could be the other way around, you want to treat +binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs. For these cases, the metrics where this distinction would make a difference, expose the -``is_multiclass`` argument. +``is_multiclass`` argument. Let's see how this is used on the example of +:class:`~pytorch_lightning.metrics.classification.StatScores` metric. + +First, let's consider the case with label predictions with 2 classes, which we want to +treat as binary. + +.. testcode:: + + from pytorch_lightning.metrics.functional import stat_scores + + # These inputs are supposed to be binary, but appear as multi-class + preds = torch.tensor([0, 1, 0]) + target = torch.tensor([1, 1, 0]) + +As you can see below, by default the inputs are treated +as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary - +which is the same as converting the predictions to float beforehand. + +.. doctest:: + + >>> stat_scores(preds, target, reduce='macro', num_classes=2) + tensor([[1, 1, 1, 0, 1], + [1, 0, 1, 1, 2]]) + >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False) + tensor([[1, 0, 1, 1, 2]]) + >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1) + tensor([[1, 0, 1, 1, 2]]) + +Next, consider the opposite example: inputs are binary (as predictions are probabilities), +but we would like to treat them as 2-class multi-class, to obtain the metric for both classes. + +.. testcode:: + + preds = torch.tensor([0.2, 0.7, 0.3]) + target = torch.tensor([1, 1, 0]) + +In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class. + +.. 
doctest:: + + >>> stat_scores(preds, target, reduce='macro', num_classes=1) + tensor([[1, 0, 1, 1, 2]]) + >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True) + tensor([[1, 1, 1, 0, 1], + [1, 0, 1, 1, 2]]) + Class Metrics (Classification) ------------------------------ @@ -323,6 +372,13 @@ ROC :noindex: +StatScores +~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.StatScores + :noindex: + + Functional Metrics (Classification) ----------------------------------- @@ -444,7 +500,7 @@ select_topk [func] stat_scores [func] ~~~~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.stat_scores +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores :noindex: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index c792fc5e71b03..53f6b5b6a6123 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -24,6 +24,7 @@ ROC, FBeta, F1, + StatScores ) from pytorch_lightning.metrics.regression import ( # noqa: F401 diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 78163a9673887..a338bfe44f942 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -19,3 +19,4 @@ from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401 from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401 +from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401 diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index e248c132026a4..e50b948f389f3 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ 
b/pytorch_lightning/metrics/classification/accuracy.py @@ -21,7 +21,7 @@ class Accuracy(Metric): r""" - Computes `Accuracy `_: + Computes `Accuracy `__: .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -43,7 +43,7 @@ class Accuracy(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -54,27 +54,29 @@ class Accuracy(Metric): Whether to compute subset accuracy for multi-label and multi-dimensional multi-class inputs (has no effect for other input types). - For multi-label inputs, if the parameter is set to `True`, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to `False`, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to `False`, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. + - For multi-label inputs, if the parameter is set to ``True``, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to ``False``, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. 
``preds = preds.flatten()`` and same for ``target``). + + - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all + sub-samples (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to ``False``, then all sub-samples are counted separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. + compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step. process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None + Callback that performs the allgather operation on the metric state. 
When ``None``, DDP + will be used to perform the allgather Example: @@ -113,11 +115,11 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - if not 0 <= threshold <= 1: - raise ValueError("The `threshold` should lie in the [0,1] interval.") + if not 0 < threshold < 1: + raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - if top_k is not None and top_k <= 0: - raise ValueError("The `top_k` should be an integer larger than 1.") + if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): + raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") self.threshold = threshold self.top_k = top_k diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index b3281cd60987c..cdd3102744c55 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -39,14 +39,15 @@ class HammingDistance(Metric): Args: threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. + (0 or 1) predictions, in the case of binary or multi-label inputs. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + Specify the process group on which synchronization is called. 
+ default: ``None`` (which selects the entire world) dist_sync_fn: Callback that performs the allgather operation on the metric state. When ``None``, DDP will be used to perform the all gather. @@ -80,8 +81,8 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - if not 0 <= threshold <= 1: - raise ValueError("The `threshold` should lie in the [0,1] interval.") + if not 0 < threshold < 1: + raise ValueError("The `threshold` should lie in the (0,1) interval.") self.threshold = threshold def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index bac9be59b1c9f..b9b8d7902976b 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -39,8 +39,8 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold if preds_float and (preds.min() < 0 or preds.max() > 1): raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - if threshold > 1 or threshold < 0: - raise ValueError("The `threshold` should be a probability in [0,1].") + if not 0 < threshold < 1: + raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") if is_multiclass is False and target.max() > 1: raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") @@ -58,7 +58,7 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) care of that. It returns the name of the case in which the inputs fall, and the implied - number of classes (from the C dim for multi-class data, or extra dim(s) for + number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for multi-label data). 
""" @@ -181,13 +181,19 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): - if "multi-class" not in case or not preds_float: - raise ValueError( - "You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class" - " with probability predictions." - ) + if case == "binary": + raise ValueError("You can not use `top_k` parameter with binary data.") + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError("The `top_k` has to be an integer larger than 0.") + if not preds_float: + raise ValueError("You have set `top_k`, but you do not have probability predictions.") if is_multiclass is False: raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") + if case == "multi-label" and is_multiclass: + raise ValueError( + "If you want to transform multi-label data to 2 class multi-dimensional" + "multi-class data using `is_multiclass=True`, you can not use `top_k`." + ) if top_k >= implied_classes: raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") @@ -216,45 +222,36 @@ def _check_classification_inputs( When ``num_classes`` is not specified in these cases, consistency of the highest target value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - If ``top_k`` is set (not None) for inputs which are not (multi-dimensional) multi class - with probabilities, then an error is raised. Similarly if ``top_k`` is set to a number - that is higher than or equal to the ``C`` dimension of ``preds``. + If ``top_k`` is set (not None) for inputs that do not have probability predictions (and + are not binary), an error is raised. Similarly if ``top_k`` is set to a number that + is higher than or equal to the ``C`` dimension of ``preds``, an error is raised. 
Preds and target tensors are expected to be squeezed already - all dimensions should be - greater than 1, except perhaps the first one (N). + greater than 1, except perhaps the first one (``N``). Args: preds: Tensor with predictions (labels or probabilities) target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0,1) predictions, in the case of binary or multi-label inputs. num_classes: Number of classes. If not explicitly set, the number of classes will be infered either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` tensor, where applicable. top_k: Number of highest probability entries for each sample to convert to 1s - relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as 1 for these inputs. + only for inputs with probability predictions. The default value (``None``) will be + interpreted as 1 for these inputs. If this parameter is set for multi-label inputs, + it will take precedence over ``threshold``. - Should be left unset (``None``) for all other types of inputs. + Should be left unset (``None``) for inputs with label predictions. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be (see :ref:`metrics: Input types` documentation section for - input classification and examples of the use of this parameter). Should be left at default - value (``None``) in most cases. + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. 
- The special cases where this parameter should be set are: - - - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional - multi-class with 2 classes, respectively. The probabilities are interpreted as the - probability of the "1" class, and thresholding still applies as usual. In this case - the parameter should be set to ``True``. - - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes - as binary or multi-label inputs, respectively. This is mainly meant for the case when - inputs are labels, but will work if they are probabilities as well. For this case the - parameter should be set to ``False``. Return: case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or @@ -294,7 +291,7 @@ def _check_classification_inputs( _check_num_classes_ml(num_classes, is_multiclass, implied_classes) # Check that top_k is consistent - if top_k: + if top_k is not None: _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) return case @@ -364,7 +361,7 @@ def _input_format_classification( target: Tensor with ground truth labels, always integers (labels) threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5 + (0 or 1) predictions, in the case of binary or multi-label inputs. num_classes: Number of classes. If not explicitly set, the number of classes will be infered either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` @@ -377,27 +374,16 @@ def _input_format_classification( Should be left unset (``None``) for all other types of inputs. is_multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be (see :ref:`metrics: Input types` documentation section for - input classification and examples of the use of this parameter). 
Should be left at default - value (``None``) in most cases. - - The special cases where this parameter should be set are: - - - When you want to treat binary or multi-label inputs as multi-class or multi-dimensional - multi-class with 2 classes, respectively. The probabilities are interpreted as the - probability of the "1" class, and thresholding still applies as usual. In this case - the parameter should be set to ``True``. - - When you want to treat multi-class or multi-dimensional mulit-class inputs with 2 classes - as binary or multi-label inputs, respectively. This is mainly meant for the case when - inputs are labels, but will work if they are probabilities as well. For this case the - parameter should be set to ``False``. + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. Returns: - preds: binary tensor of shape (N, C) or (N, C, X) - target: binary tensor of shape (N, C) or (N, C, X) - case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' + preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` + target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` + case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or + ``'multi-dim multi-class'`` """ # Remove excess dimensions if preds.shape[0] == 1: @@ -419,21 +405,22 @@ def _input_format_classification( top_k=top_k, ) - top_k = top_k if top_k else 1 - - if case in ["binary", "multi-label"]: + if case in ["binary", "multi-label"] and not top_k: preds = (preds >= threshold).int() num_classes = num_classes if not is_multiclass else 2 + if case == "multi-label" and top_k: + preds = select_topk(preds, top_k) + if "multi-class" in case or is_multiclass: if preds.is_floating_point(): num_classes = preds.shape[1] - preds = select_topk(preds, top_k) + preds = select_topk(preds, top_k or 1) else: num_classes = num_classes if num_classes else 
max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, max(2,num_classes)) + preds = to_onehot(preds, max(2, num_classes)) - target = to_onehot(target, max(2,num_classes)) + target = to_onehot(target, max(2, num_classes)) if is_multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py new file mode 100644 index 0000000000000..d7b33ce1f8099 --- /dev/null +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -0,0 +1,250 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Any, Callable + +import torch +from pytorch_lightning.metrics import Metric +from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update, _stat_scores_compute + + +class StatScores(Metric): + """Computes the number of true positives, false positives, true negatives, false negatives. + Related to `Type I and Type II errors `__ + and the `confusion matrix `__. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. + + Accepts all inputs listed in :ref:`metrics:Input types`. 
+ + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Each statistic is represented by a single integer. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` + to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Each statistic is represented by a ``(N, )`` 1d tensor. + + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + ignore_index: + Specify a class (label) to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and + ``reduce='macro'``, the class statistics for the ignored class will all be returned + as ``-1``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class (see :ref:`metrics:Input types` for the definition of input types). 
+ + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then the outputs are concatenated together. In each + sample the extra axes ``...`` are flattened to become the sub-sample axis, and + statistics for each sample are computed by treating the sub-sample axis as the + ``N`` axis for that sample. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are + flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section ` + for a more detailed explanation and examples. + + compute_on_step: + Forward only calls ``update()`` and return ``None`` if this is set to ``False``. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step + process_group: + Specify the process group on which synchronization is called. + default: ``None`` (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the allgather. 
+ + Example: + + >>> from pytorch_lightning.metrics.classification import StatScores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores = StatScores(reduce='macro', num_classes=3) + >>> stat_scores(preds, target) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores = StatScores(reduce='micro') + >>> stat_scores(preds, target) + tensor([2, 2, 6, 2, 4]) + + """ + + def __init__( + self, + threshold: float = 0.5, + top_k: Optional[int] = None, + reduce: str = "micro", + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + mdmc_reduce: Optional[str] = None, + is_multiclass: Optional[bool] = None, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.reduce = reduce + self.mdmc_reduce = mdmc_reduce + self.num_classes = num_classes + self.threshold = threshold + self.is_multiclass = is_multiclass + self.ignore_index = ignore_index + self.top_k = top_k + + if not 0 < threshold < 1: + raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {reduce} is not valid.") + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + if mdmc_reduce != 
"samplewise" and reduce != "samples": + if reduce == "micro": + zeros_shape = [] + elif reduce == "macro": + zeros_shape = (num_classes,) + default, reduce_fn = lambda: torch.zeros(zeros_shape, dtype=torch.long), "sum" + else: + default, reduce_fn = lambda: [], None + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + """ + + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + top_k=self.top_k, + is_multiclass=self.is_multiclass, + ignore_index=self.ignore_index, + ) + + # Update states + if self.reduce != "samples" and self.mdmc_reduce != "samplewise": + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + + def compute(self) -> torch.Tensor: + """ + Computes the stat scores based on inputs passed in to ``update`` previously. + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). 
The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + + """ + if isinstance(self.tp, list): + tp = torch.cat(self.tp) + fp = torch.cat(self.fp) + tn = torch.cat(self.tn) + fn = torch.cat(self.fn) + else: + tp, fp, tn, fn = self.tp, self.fp, self.tn, self.fn + + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 1b28f534f80e7..3aa533504be6b 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -24,7 +24,6 @@ precision, precision_recall, recall, - stat_scores, stat_scores_multiple_classes, to_categorical, to_onehot, @@ -44,3 +43,4 @@ from pytorch_lightning.metrics.functional.roc import roc # noqa: F401 from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401 from pytorch_lightning.metrics.functional.ssim import 
ssim # noqa: F401 +from pytorch_lightning.metrics.functional.stat_scores import stat_scores # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 8ba0e49b881b8..23adfc88ecd98 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -23,6 +23,9 @@ def _accuracy_update( preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + if mode == "multi-label" and top_k: + raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") + if mode == "binary" or (mode == "multi-label" and subset_accuracy): correct = (preds == target).all(dim=1).sum() total = torch.tensor(target.shape[0], device=target.device) @@ -51,8 +54,7 @@ def accuracy( top_k: Optional[int] = None, subset_accuracy: bool = False, ) -> torch.Tensor: - r""" - Computes `Accuracy `_: + r"""Computes `Accuracy `_: .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -76,7 +78,7 @@ def accuracy( target: Ground truth labels threshold: Threshold probability value for transforming probability predictions to binary - `(0,1)` predictions, in the case of binary or multi-label inputs. + (0,1) predictions, in the case of binary or multi-label inputs. top_k: Number of highest probability predictions considered to find the correct label, relevant only for (multi-dimensional) multi-class inputs with probability predictions. The @@ -87,17 +89,17 @@ def accuracy( Whether to compute subset accuracy for multi-label and multi-dimensional multi-class inputs (has no effect for other input types). - For multi-label inputs, if the parameter is set to `True`, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to `False`, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. 
``preds = preds.flatten()`` and same for ``target``). + - For multi-label inputs, if the parameter is set to ``True``, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to ``False``, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to `False`, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. + - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all + sub-sample (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. 
Example: @@ -113,8 +115,5 @@ def accuracy( tensor(0.6667) """ - if top_k is not None and top_k <= 0: - raise ValueError("The `top_k` should be an integer larger than 1.") - correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy) return _accuracy_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index b1d2893412582..094feeb6f729a 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -139,39 +139,15 @@ def stat_scores_multiple_classes( Calculates the number of true positive, false positive, true negative and false negative for each class - Args: - pred: prediction tensor - target: target tensor - num_classes: number of classes if known - argmax_dim: if pred is a tensor of probabilities, this indicates the - axis the argmax transformation will be applied over - reduction: a method to reduce metric score over labels (default: none) - Available reduction methods: - - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements - - Return: - True Positive, False Positive, True Negative, False Negative, Support - - Example: - - >>> x = torch.tensor([1, 2, 3]) - >>> y = torch.tensor([0, 2, 3]) - >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y) - >>> tps - tensor([0., 0., 1., 1.]) - >>> fps - tensor([0., 1., 0., 0.]) - >>> tns - tensor([2., 2., 2., 2.]) - >>> fns - tensor([1., 0., 0., 0.]) - >>> sups - tensor([1., 0., 1., 1.]) + .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores` """ + + rank_zero_warn( + "This `stat_scores_multiple_classes` was deprecated in v1.2.0 in favor of" + " `from pytorch_lightning.metrics.functional import stat_scores`." 
+ " It will be removed in v1.4.0", DeprecationWarning + ) if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index 7d8ecafd08b00..a814055061840 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -55,7 +55,7 @@ def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float target: Ground truth threshold: Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. + (0 or 1) predictions, in the case of binary or multi-label inputs. Example: diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py new file mode 100644 index 0000000000000..27d46ee31c39c --- /dev/null +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -0,0 +1,281 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Tuple, Optional + +import torch +from pytorch_lightning.metrics.classification.helpers import _input_format_classification + + +def _del_column(tensor: torch.Tensor, index: int): + """ Delete the column at index.""" + + return torch.cat([tensor[:, :index], tensor[:, (index + 1) :]], 1) + + +def _stat_scores( + preds: torch.Tensor, target: torch.Tensor, reduce: str = "micro" +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate the number of tp, fp, tn, fn. + + Args: + preds: + An ``(N, C)`` or ``(N, C, X)`` tensor of predictions (0 or 1) + target: + An ``(N, C)`` or ``(N, C, X)`` tensor of true labels (0 or 1) + reduce: + One of ``'micro'``, ``'macro'``, ``'samples'`` + + Return: + Returns a tuple of 4 tensors; tp, fp, tn, fn. + The shape of the returned tensors depends on the shape of the inputs + and the ``reduce`` parameter: + + If inputs are of the shape ``(N, C)``, then + - If ``reduce='micro'``, the returned tensors are 1 element tensors + - If ``reduce='macro'``, the returned tensors are ``(C,)`` tensors + - If ``reduce='samples'``, the returned tensors are ``(N,)`` tensors + + If inputs are of the shape ``(N, C, X)``, then + - If ``reduce='micro'``, the returned tensors are ``(N,)`` tensors + - If ``reduce='macro'``, the returned tensors are ``(N,C)`` tensors + - If ``reduce='samples'``, the returned tensors are ``(N,X)`` tensors + """ + if reduce == "micro": + dim = [0, 1] if preds.ndim == 2 else [1, 2] + elif reduce == "macro": + dim = 0 if preds.ndim == 2 else 2 + elif reduce == "samples": + dim = 1 + + true_pred, false_pred = target == preds, target != preds + pos_pred, neg_pred = preds == 1, preds == 0 + + tp = (true_pred * pos_pred).sum(dim=dim) + fp = (false_pred * pos_pred).sum(dim=dim) + + tn = (true_pred * neg_pred).sum(dim=dim) + fn = (false_pred * neg_pred).sum(dim=dim) + + return tp.long(), fp.long(), tn.long(), fn.long() + + +def _stat_scores_update( + preds: torch.Tensor, + target: torch.Tensor,
+ reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + num_classes: Optional[int] = None, + top_k: Optional[int] = None, + threshold: float = 0.5, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + + preds, target, _ = _input_format_classification( + preds, target, threshold=threshold, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + ) + + if ignore_index is not None and not 0 <= ignore_index < preds.shape[1]: + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {preds.shape[1]} classes") + + if ignore_index is not None and preds.shape[1] == 1: + raise ValueError("You can not use `ignore_index` with binary data.") + + if preds.ndim == 3: + if not mdmc_reduce: + raise ValueError( + "When your inputs are multi-dimensional multi-class, you have to set the `mdmc_reduce` parameter" + ) + if mdmc_reduce == "global": + preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) + target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + + # Delete what is in ignore_index, if applicable (and classes don't matter): + if ignore_index is not None and reduce != "macro": + preds = _del_column(preds, ignore_index) + target = _del_column(target, ignore_index) + + tp, fp, tn, fn = _stat_scores(preds, target, reduce=reduce) + + # Take care of ignore_index (compare against None: class index 0 is a valid value) + if ignore_index is not None and reduce == "macro": + tp[..., ignore_index] = -1 + fp[..., ignore_index] = -1 + tn[..., ignore_index] = -1 + fn[..., ignore_index] = -1 + + return tp, fp, tn, fn + + +def _stat_scores_compute(tp: torch.Tensor, fp: torch.Tensor, tn: torch.Tensor, fn: torch.Tensor) -> torch.Tensor: + + outputs = [ + tp.unsqueeze(-1), + fp.unsqueeze(-1), + tn.unsqueeze(-1), + fn.unsqueeze(-1), + tp.unsqueeze(-1) + fn.unsqueeze(-1), # support + ] + outputs = torch.cat(outputs, -1) + outputs = torch.where(outputs < 0, torch.tensor(-1, device=outputs.device), outputs) + +
return outputs + + +def stat_scores( + preds: torch.Tensor, + target: torch.Tensor, + reduce: str = "micro", + mdmc_reduce: Optional[str] = None, + num_classes: Optional[int] = None, + top_k: Optional[int] = None, + threshold: float = 0.5, + is_multiclass: Optional[bool] = None, + ignore_index: Optional[int] = None, +) -> torch.Tensor: + """Computes the number of true positives, false positives, true negatives, false negatives. + Related to `Type I and Type II errors `__ + and the `confusion matrix `__. + + The reduction method (how the statistics are aggregated) is controlled by the + ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the + multi-dimensional multi-class case. Accepts all inputs listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities or labels) + target: Ground truth values + threshold: + Threshold probability value for transforming probability predictions to binary + (0 or 1) predictions, in the case of binary or multi-label inputs. + + top_k: + Number of highest probability entries for each sample to convert to 1s - relevant + only for inputs with probability predictions. If this parameter is set for multi-label + inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, + this parameter defaults to 1. + + Should be left unset (``None``) for inputs with label predictions. + + reduce: + Defines the reduction that is applied. Should be one of the following: + + - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] + combinations (globally). Each statistic is represented by a single integer. + - ``'macro'``: Counts the statistics for each class separately (over all samples). + Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` + to be set. + - ``'samples'``: Counts the statistics for each sample separately (over all classes). + Each statistic is represented by a ``(N, )`` 1d tensor. 
+ + Note that what is considered a sample in the multi-dimensional multi-class case + depends on the value of ``mdmc_reduce``. + + num_classes: + Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. + + ignore_index: + Specify a class (label) to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. If an index is ignored, and + ``reduce='macro'``, the class statistics for the ignored class will all be returned + as ``-1``. + + mdmc_reduce: + Defines how the multi-dimensional multi-class inputs are handled. Should be + one of the following: + + - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional + multi-class (see :ref:`metrics:Input types` for the definition of input types). + + - ``'samplewise'``: In this case, the statistics are computed separately for each + sample on the ``N`` axis, and then the outputs are concatenated together. In each + sample the extra axes ``...`` are flattened to become the sub-sample axis, and + statistics for each sample are computed by treating the sub-sample axis as the + ``N`` axis for that sample. + + - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are + flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they + were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + + is_multiclass: + Used only in certain special cases, where you want to treat inputs as a different type + than what they appear to be. See the parameter's + :ref:`documentation section <metrics:Using the is_multiclass parameter>` + for a more detailed explanation and examples. + + Return: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``).
The + shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional + multi-class data) parameters: + + - If the data is not multi-dimensional multi-class, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)``, + where ``C`` stands for the number of classes + - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for + the number of samples + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then + + - If ``reduce='micro'``, the shape will be ``(5, )`` + - If ``reduce='macro'``, the shape will be ``(C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for + the product of sizes of all "extra" dimensions of the data (i.e. all dimensions + except for ``C`` and ``N``) + + - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then + + - If ``reduce='micro'``, the shape will be ``(N, 5)`` + - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` + - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` + + Example: + + >>> from pytorch_lightning.metrics.functional import stat_scores + >>> preds = torch.tensor([1, 0, 2, 1]) + >>> target = torch.tensor([1, 1, 2, 0]) + >>> stat_scores(preds, target, reduce='macro', num_classes=3) + tensor([[0, 1, 2, 1, 1], + [1, 1, 1, 1, 2], + [1, 0, 3, 0, 1]]) + >>> stat_scores(preds, target, reduce='micro') + tensor([2, 2, 6, 2, 4]) + + """ + + if reduce not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {reduce} is not valid.") + + if mdmc_reduce not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") + + if reduce == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 
1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + top_k=top_k, + threshold=threshold, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + ) + return _stat_scores_compute(tp, fp, tn, fn) diff --git a/tests/deprecated_api/test_remove_1-4.py b/tests/deprecated_api/test_remove_1-4.py index 806916cdce344..9a7a970aecaf7 100644 --- a/tests/deprecated_api/test_remove_1-4.py +++ b/tests/deprecated_api/test_remove_1-4.py @@ -13,6 +13,7 @@ # limitations under the License. """Test deprecated functionality which will be removed in vX.Y.Z""" import pytest +import torch from pytorch_lightning import Trainer from tests.deprecated_api import _soft_unimport_module @@ -66,3 +67,9 @@ def test_v1_4_0_deprecated_trainer_attributes(): trainer.use_horovod = True assert trainer.use_horovod + + +def test_v1_4_0_deprecated_metrics(): + from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes + with pytest.deprecated_call(match='will be removed in v1.4'): + stat_scores_multiple_classes(pred=torch.tensor([0, 1]), target=torch.tensor([0, 1])) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 30d3f06707301..8cfe1dd46ec50 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -124,6 +124,7 @@ def _mlmd_prob_to_mc_preds_tr(x): (_ml_prob, None, None, None, "multi-label", _thrs, _idn), (_ml, None, False, None, "multi-dim multi-class", _idn, _idn), (_ml_prob, None, None, None, "multi-label", _ml_preds_tr, _rshp1), + (_ml_prob, None, None, 2, "multi-label", _top2, _rshp1), (_mlmd, None, False, None, "multi-dim multi-class", _rshp1, _rshp1), (_mc, NUM_CLASSES, None, None, "multi-class", _onehot, _onehot), (_mc_prob, None, None, None, "multi-class", 
_top1, _onehot), @@ -199,9 +200,11 @@ def test_threshold(): ######################################################################## -def test_incorrect_threshold(): +@pytest.mark.parametrize("threshold", [-0.5, 0.0, 1.0, 1.5]) +def test_incorrect_threshold(threshold): + preds, target = rand(size=(7,)), randint(high=2, size=(7,)) with pytest.raises(ValueError): - _input_format_classification(preds=rand(size=(7,)), target=randint(high=2, size=(7,)), threshold=1.5) + _input_format_classification(preds, target, threshold=threshold) @pytest.mark.parametrize( @@ -272,19 +275,25 @@ def test_incorrect_inputs(preds, target, num_classes, is_multiclass): @pytest.mark.parametrize( "preds, target, num_classes, is_multiclass, top_k", [ - # Topk set with non (md)mc prob data + # Topk set with non (md)mc or ml prob data (_bin.preds[0], _bin.target[0], None, None, 2), (_bin_prob.preds[0], _bin_prob.target[0], None, None, 2), (_mc.preds[0], _mc.target[0], None, None, 2), (_ml.preds[0], _ml.target[0], None, None, 2), (_mlmd.preds[0], _mlmd.target[0], None, None, 2), - (_ml_prob.preds[0], _ml_prob.target[0], None, None, 2), - (_mlmd_prob.preds[0], _mlmd_prob.target[0], None, None, 2), (_mdmc.preds[0], _mdmc.target[0], None, None, 2), + # top_k = 0 + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0), + # top_k = float + (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, None, 0.123), # top_k =2 with 2 classes, is_multiclass=False (_mc_prob_2cls.preds[0], _mc_prob_2cls.target[0], None, False, 2), # top_k = number of classes (C dimension) (_mc_prob.preds[0], _mc_prob.target[0], None, None, NUM_CLASSES), + # is_multiclass = True for ml prob inputs, top_k set + (_ml_prob.preds[0], _ml_prob.target[0], None, True, 2), + # top_k = num_classes for ml prob inputs + (_ml_prob.preds[0], _ml_prob.target[0], None, True, NUM_CLASSES), ], ) def test_incorrect_inputs_topk(preds, target, num_classes, is_multiclass, top_k): diff --git 
a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py new file mode 100644 index 0000000000000..62fc8096089a0 --- /dev/null +++ b/tests/metrics/classification/test_stat_scores.py @@ -0,0 +1,231 @@ +from functools import partial +from typing import Callable, Optional + +import numpy as np +import pytest +import torch +from sklearn.metrics import multilabel_confusion_matrix + +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from pytorch_lightning.metrics import StatScores +from pytorch_lightning.metrics.functional import stat_scores +from tests.metrics.classification.inputs import ( + _binary_inputs, + _binary_prob_inputs, + _multiclass_inputs, + _multiclass_prob_inputs as _mc_prob, + _multilabel_inputs, + _multilabel_prob_inputs as _ml_prob, + _multidim_multiclass_inputs as _mdmc, + _multidim_multiclass_prob_inputs as _mdmc_prob, +) +from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester + +torch.manual_seed(42) + + +def _sk_stat_scores(preds, target, reduce, num_classes, is_multiclass, ignore_index, top_k, mdmc_reduce=None): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + ) + sk_preds, sk_target = preds.numpy(), target.numpy() + + if reduce != "macro" and ignore_index and preds.shape[1] > 1: + sk_preds = np.delete(sk_preds, ignore_index, 1) + sk_target = np.delete(sk_target, ignore_index, 1) + + if preds.shape[1] == 1 and reduce == "samples": + sk_target = sk_target.T + sk_preds = sk_preds.T + + sk_stats = multilabel_confusion_matrix( + sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 + ) + + if preds.shape[1] == 1 and reduce != "samples": + sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] + else: + sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + + if reduce == "micro": + sk_stats = sk_stats.sum(axis=0, 
keepdims=True) + + sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + + if reduce == "micro": + sk_stats = sk_stats[0] + + if reduce == "macro" and ignore_index and preds.shape[1]: + sk_stats[ignore_index, :] = -1 + + return sk_stats + + +def _sk_stat_scores_mdmc(preds, target, reduce, mdmc_reduce, num_classes, is_multiclass, ignore_index, top_k): + preds, target, _ = _input_format_classification( + preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass, top_k=top_k + ) + + if mdmc_reduce == "global": + shape_permute = list(range(preds.ndim)) + shape_permute[1] = shape_permute[-1] + shape_permute[2:] = range(1, len(shape_permute) - 1) + + preds = preds.permute(*shape_permute).reshape(-1, preds.shape[1]) + target = target.permute(*shape_permute).reshape(-1, target.shape[1]) + + return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k) + elif mdmc_reduce == "samplewise": + scores = [] + + for i in range(preds.shape[0]): + pred_i = preds[i, ...].T + target_i = target[i, ...].T + scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k) + + scores.append(np.expand_dims(scores_i, 0)) + + return np.concatenate(scores) + + +@pytest.mark.parametrize( + "reduce, mdmc_reduce, num_classes, inputs, ignore_index", + [ + ["unknown", None, None, _binary_inputs, None], + ["micro", "unknown", None, _binary_inputs, None], + ["macro", None, None, _binary_inputs, None], + ["micro", None, None, _mdmc_prob, None], + ["micro", None, None, _binary_prob_inputs, 0], + ["micro", None, None, _mc_prob, NUM_CLASSES], + ["micro", None, NUM_CLASSES, _mc_prob, NUM_CLASSES], + ], +) +def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): + """Test a combination of parameters that are invalid and should raise an error. 
+ + This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting + ``num_classes`` when ``reduce='macro'``, not setting ``mdmc_reduce`` when inputs + are multi-dim multi-class, setting ``ignore_index`` when inputs are binary, as well + as setting ``ignore_index`` to a value equal to or higher than the number of classes. + """ + with pytest.raises(ValueError): + stat_scores( + inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index + ) + + with pytest.raises(ValueError): + sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) + sts(inputs.preds[0], inputs.target[0]) + + +def test_wrong_threshold(): + with pytest.raises(ValueError): + StatScores(threshold=1.5) + + +@pytest.mark.parametrize("ignore_index", [None, 1]) +@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) +@pytest.mark.parametrize( + "preds, target, sk_fn, mdmc_reduce, num_classes, is_multiclass, top_k", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_stat_scores, None, 1, None, None), + (_binary_inputs.preds, _binary_inputs.target, _sk_stat_scores, None, 1, False, None), + (_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_ml_prob.preds, _ml_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), + (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_stat_scores, None, NUM_CLASSES, False, None), + (_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_mc_prob.preds, _mc_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2), + (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_stat_scores, None, NUM_CLASSES, None, None), + (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None), + (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "samplewise", NUM_CLASSES, None, None), + (_mdmc.preds, _mdmc.target, _sk_stat_scores_mdmc, "global",
NUM_CLASSES, None, None), + (_mdmc_prob.preds, _mdmc_prob.target, _sk_stat_scores_mdmc, "global", NUM_CLASSES, None, None), + ], +) +class TestStatScores(MetricTester): + # DDP tests temporarily disabled due to hanging issues + @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_stat_scores_class( + self, + ddp: bool, + dist_sync_on_step: bool, + sk_fn: Callable, + preds: torch.Tensor, + target: torch.Tensor, + reduce: str, + mdmc_reduce: Optional[str], + num_classes: Optional[int], + is_multiclass: Optional[bool], + ignore_index: Optional[int], + top_k: Optional[int], + ): + if ignore_index and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=StatScores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + top_k=top_k, + ), + dist_sync_on_step=dist_sync_on_step, + metric_args={ + "num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + "top_k": top_k, + }, + check_dist_sync_on_step=True, + check_batch=True, + ) + + def test_stat_scores_fn( + self, + sk_fn: Callable, + preds: torch.Tensor, + target: torch.Tensor, + reduce: str, + mdmc_reduce: Optional[str], + num_classes: Optional[int], + is_multiclass: Optional[bool], + ignore_index: Optional[int], + top_k: Optional[int], + ): + if ignore_index and preds.ndim == 2: + pytest.skip("Skipping ignore_index test with binary inputs.") + + self.run_functional_metric_test( + preds, + target, + metric_functional=stat_scores, + sk_metric=partial( + sk_fn, + reduce=reduce, + mdmc_reduce=mdmc_reduce, + num_classes=num_classes, + is_multiclass=is_multiclass, + ignore_index=ignore_index, + top_k=top_k, + ), + metric_args={ + 
"num_classes": num_classes, + "reduce": reduce, + "mdmc_reduce": mdmc_reduce, + "threshold": THRESHOLD, + "is_multiclass": is_multiclass, + "ignore_index": ignore_index, + "top_k": top_k, + }, + )