From 20eab4309b9bcde7a3157a1e3fb72711048d6c7c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 22 Nov 2022 14:45:52 +0100 Subject: [PATCH] Code cleaning after classification refactor 2/n (#1252) * functionals cleanup * remove old functions * revert stat score due to dice * clean docstring * more docstring cleaning * remaining changes to impl * remove old warning * fix arg ordering * fix import * try fixing docs * fix integration testing * fix top_k arg * fix broken tests * fix more unittests * fix more unittests * another fix * add tasks * fix more doctests * fix mypy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: Jirka Borovec Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + README.md | 8 +- .../classification/precision_recall.rst | 16 - docs/source/pages/lightning.rst | 14 +- docs/source/pages/overview.rst | 43 +- docs/source/pages/quickstart.rst | 8 +- src/torchmetrics/classification/accuracy.py | 310 +-------- src/torchmetrics/classification/auroc.py | 200 +----- .../classification/average_precision.py | 165 +---- .../classification/calibration_error.py | 133 +--- .../classification/cohen_kappa.py | 122 +--- .../classification/confusion_matrix.py | 153 +---- src/torchmetrics/classification/dice.py | 96 ++- src/torchmetrics/classification/f_beta.py | 373 ++--------- src/torchmetrics/classification/hamming.py | 124 +--- src/torchmetrics/classification/hinge.py | 153 +---- src/torchmetrics/classification/jaccard.py | 161 +---- .../classification/matthews_corrcoef.py | 127 +--- .../classification/precision_recall.py | 390 ++--------- .../classification/precision_recall_curve.py | 146 +---- .../recall_at_fixed_precision.py | 41 ++ src/torchmetrics/classification/roc.py | 169 +---- .../classification/specificity.py | 202 +----- .../classification/stat_scores.py | 285 +------- src/torchmetrics/collections.py | 56 +- src/torchmetrics/functional/__init__.py | 3 +- .../functional/classification/__init__.py | 1 - .../functional/classification/accuracy.py | 443 +------------ .../functional/classification/auroc.py | 279 +------- .../classification/average_precision.py | 260 +------- .../classification/calibration_error.py | 124 +--- .../functional/classification/cohen_kappa.py | 119 +--- .../classification/confusion_matrix.py | 192 +----- .../functional/classification/dice.py | 9 - .../functional/classification/f_beta.py | 415 ++---------- .../functional/classification/hamming.py | 124 +--- .../functional/classification/hinge.py | 250 +------ .../functional/classification/jaccard.py | 186 +----- .../classification/matthews_corrcoef.py | 109 +--- .../classification/precision_recall.py | 614 ++---------------- .../classification/precision_recall_curve.py | 300 +-------- .../recall_at_fixed_precision.py | 56 +- .../functional/classification/roc.py | 299 ++------- .../functional/classification/specificity.py | 232 +------ .../functional/classification/stat_scores.py | 212 +----- src/torchmetrics/utilities/checks.py | 8 +- src/torchmetrics/wrappers/bootstrapping.py | 6 +- src/torchmetrics/wrappers/classwise.py | 37 +- src/torchmetrics/wrappers/minmax.py | 8 +- src/torchmetrics/wrappers/tracker.py | 5 +- tests/integrations/test_lightning.py | 7 +- tests/unittests/bases/test_collections.py | 172 +++-- tests/unittests/bases/test_metric.py | 5 +- .../unittests/wrappers/test_bootstrapping.py 
| 7 +- tests/unittests/wrappers/test_classwise.py | 33 +- tests/unittests/wrappers/test_minmax.py | 9 +- tests/unittests/wrappers/test_multioutput.py | 6 +- tests/unittests/wrappers/test_tracker.py | 54 +- 58 files changed, 1434 insertions(+), 6648 deletions(-) delete mode 100644 docs/source/classification/precision_recall.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index d1a215795cf..f0a27771cf9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Changed minimum Pytorch version to be 1.8 ([#1263](https://github.com/Lightning-AI/metrics/pull/1263)) +- Changed interface for all functional and modular classification metrics after refactor ([#1252](https://github.com/Lightning-AI/metrics/pull/1252)) + + ### Deprecated - diff --git a/README.md b/README.md index 44be9ae2d0b..c45228afe01 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ import torch import torchmetrics # initialize metric -metric = torchmetrics.Accuracy() +metric = torchmetrics.Accuracy(task="multiclass", num_classes=5) # move the metric to device you want computations to take place device = "cuda" if torch.cuda.is_available() else "cpu" @@ -169,7 +169,7 @@ def metric_ddp(rank, world_size): dist.init_process_group("gloo", rank=rank, world_size=world_size) # initialize model - metric = torchmetrics.Accuracy() + metric = torchmetrics.Accuracy(task="multiclass", num_classes=5) # define a model and append your metric to it # this allows metric states to be placed on correct accelerators when @@ -263,7 +263,9 @@ import torchmetrics preds = torch.randn(10, 5).softmax(dim=-1) target = torch.randint(5, (10,)) -acc = torchmetrics.functional.accuracy(preds, target) +acc = torchmetrics.functional.classification.multiclass_accuracy( + preds, target, num_classes=5 +) ``` ### Covered domains and example metrics diff --git a/docs/source/classification/precision_recall.rst b/docs/source/classification/precision_recall.rst deleted file mode 100644 index 9029206bd96..00000000000 --- a/docs/source/classification/precision_recall.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. customcarditem:: - :header: Precision Recall - :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg - :tags: Classification - -.. include:: ../links.rst - -################ -Precision Recall -################ - -Functional Interface -____________________ - -.. autofunction:: torchmetrics.functional.precision_recall - :noindex: diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index eb8b3715ea5..6f4b8a30ffa 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -34,7 +34,7 @@ The example below shows how to use a metric in your `LightningModule Tensor: ) -class Accuracy(StatScores): - r"""Accuracy. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes Accuracy_: +class Accuracy: + r"""Computes `Accuracy`_ .. 
math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - For multi-class and multi-dimensional multi-class data with probability or logits predictions, the - parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the - top-K highest probability or logit score items are considered to find the correct label. - - For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" - accuracy by default, which counts all labels or sub-samples separately. This can be - changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting ``subset_accuracy=True``. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. + This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryAccuracy`, :mod:`MulticlassAccuracy` and :mod:`MultilabelAccuracy` for the specific details of + each argument influence and examples. - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. 
If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). - - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If two different input modes are provided, eg. using ``multi-label`` with ``multi-class``. - ValueError: - If ``top_k`` parameter is set for ``multi-label`` inputs. 
- - Example: + Legacy Example: >>> import torch - >>> from torchmetrics import Accuracy >>> target = torch.tensor([0, 1, 2, 3]) >>> preds = torch.tensor([0, 2, 1, 3]) - >>> accuracy = Accuracy() + >>> accuracy = Accuracy(task="multiclass", num_classes=4) >>> accuracy(preds, target) tensor(0.5000) >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> accuracy = Accuracy(top_k=2) + >>> accuracy = Accuracy(task="multiclass", num_classes=3, top_k=2) >>> accuracy(preds, target) tensor(0.6667) """ - is_differentiable = False - higher_is_better = True - full_state_update: bool = False - correct: Tensor - total: Tensor def __new__( cls, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - subset_accuracy: bool = False, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, - multidim_average: Optional[Literal["global", "samplewise"]] = "global", - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryAccuracy(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassAccuracy(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAccuracy(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - threshold: float = 0.5, - num_classes: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, + multidim_average: Literal["global", "samplewise"] = "global", + top_k: Optional[int] = 1, ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - subset_accuracy: bool = False, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - num_labels: Optional[int] = None, - multidim_average: Optional[Literal["global", "samplewise"]] = "global", validate_args: bool = True, **kwargs: Any, - ) -> None: - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) - if "reduce" not in kwargs: - kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average - if "mdmc_reduce" not in kwargs: - kwargs["mdmc_reduce"] = mdmc_average - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, + ) -> Metric: + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAccuracy(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassAccuracy(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAccuracy(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - self.average = average - self.threshold = threshold - self.top_k = top_k - self.subset_accuracy = subset_accuracy - self.mode: DataType = None # type: ignore - self.multiclass = multiclass - self.ignore_index = ignore_index - - if self.subset_accuracy: - self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model (logits, probabilities, or labels) - target: Ground truth labels - """ - """ returns the mode of the data (binary, multi label, multi class, multi-dim multi class) """ - mode = _mode(preds, target, self.threshold, self.top_k, self.num_classes, self.multiclass, self.ignore_index) - - if not self.mode: - self.mode = mode - elif self.mode != mode: - raise ValueError(f"You can not use {mode} inputs with {self.mode} inputs.") - - if self.subset_accuracy and not _check_subset_validity(self.mode): - self.subset_accuracy = False - - if self.subset_accuracy: - correct, total = _subset_accuracy_update( - preds, target, threshold=self.threshold, top_k=self.top_k, ignore_index=self.ignore_index - ) - self.correct += correct - self.total += total - else: - if not self.mode: - raise RuntimeError("You have to have determined mode.") - tp, fp, tn, fn = _accuracy_update( - preds, - target, - reduce=self.reduce, - mdmc_reduce=self.mdmc_reduce, - threshold=self.threshold, - num_classes=self.num_classes, - top_k=self.top_k, - multiclass=self.multiclass, - ignore_index=self.ignore_index, - mode=self.mode, - ) - - # Update states - if self.reduce != "samples" and self.mdmc_reduce != "samplewise": - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn - else: - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - - def compute(self) -> Tensor: - """Computes accuracy based on inputs passed in to ``update`` previously.""" - if not self.mode: - raise RuntimeError("You have to have determined mode.") - if self.subset_accuracy: - return _subset_accuracy_compute(self.correct, self.total) - tp, fp, tn, fn = self._get_final_stats() - return _accuracy_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce, self.mode) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index c2d77733989..14a6d9f3e88 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -23,8 +23,6 @@ MultilabelPrecisionRecallCurve, ) from torchmetrics.functional.classification.auroc import ( - _auroc_compute, - _auroc_update, _binary_auroc_arg_validation, _binary_auroc_compute, _multiclass_auroc_arg_validation, @@ -33,9 +31,7 @@ _multilabel_auroc_compute, ) from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.enums import DataType class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -318,201 +314,55 @@ def compute(self) -> Tensor: return _multilabel_auroc_compute(state, self.num_labels, self.average, self.thresholds, self.ignore_index) -class AUROC(Metric): - r"""Area Under the Receiver Operating Characteristic Curve. +class AUROC: + r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the + ROC curve into an single number that describes the performance of a model for multiple thresholds at the same + time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryAUROC`, :mod:`MulticlassAUROC` and :mod:`MultilabelAUROC` for the specific details of + each argument influence and examples. - - Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). - Works for both binary, multilabel and multiclass problems. In the case of - multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - For non-binary input, if the ``preds`` and ``target`` tensor have the same - size the input will be interpretated as multilabel and if ``preds`` have one - dimension more than the ``target`` tensor the input will be interpretated as - multiclass. - - .. note:: - If either the positive class or negative class is completly missing in the target tensor, - the auroc score is meaningless in this case and a score of 0 will be returned together - with an warning. - - Args: - num_classes: integer with number of classes for multi-label and multiclass problems. - - Should be set to ``None`` for binary problems - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: - - ``'micro'`` computes metric globally. Only works for multilabel problems - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range ``[0, max_fpr]``. Should be a float between 0 and 1. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - ValueError: - If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. - RuntimeError: - If ``PyTorch version`` is ``below 1.6`` since ``max_fpr`` requires ``torch.bucketize`` - which is not available below 1.6. - ValueError: - If the mode of data (binary, multi-label, multi-class) changes between batches. - - Example (binary case): - >>> from torchmetrics import AUROC + Legacy Example: >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc = AUROC(pos_label=1) + >>> auroc = AUROC(task="binary") >>> auroc(preds, target) tensor(0.5000) - Example (multiclass case): >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], ... [0.85, 0.05, 0.10], ... 
[0.10, 0.10, 0.80]]) >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc = AUROC(num_classes=3) + >>> auroc = AUROC(task="multiclass", num_classes=3) >>> auroc(preds, target) tensor(0.7778) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] def __new__( cls, + task: Literal["binary", "multiclass", "multilabel"], + thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, - pos_label: Optional[int] = None, + num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", max_fpr: Optional[float] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, - num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryAUROC(max_fpr, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassAUROC(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAUROC(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - max_fpr: Optional[float] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, - num_labels: Optional[int] = None, - ignore_index: Optional[int] = None, - validate_args: bool = True, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - self.average = average - self.max_fpr = max_fpr - - allowed_average = (None, "macro", "weighted", "micro") - if self.average not in allowed_average: - raise ValueError( - f"Argument `average` expected to be one of the following: {allowed_average} but got {average}" - ) - - if self.max_fpr is not None: - if not isinstance(max_fpr, float) or not 0 < max_fpr <= 1: - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - self.mode: DataType = None # type: ignore - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `AUROC` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." - ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth labels - """ - preds, target, mode = _auroc_update(preds, target) - - self.preds.append(preds) - self.target.append(target) - - if self.mode and self.mode != mode: - raise ValueError( - "The mode of data (binary, multi-label, multi-class) should be constant, but changed" - f" between batches from {self.mode} to {mode}" - ) - self.mode = mode - - def compute(self) -> Tensor: - """Computes AUROC based on inputs passed in to ``update`` previously.""" - if not self.mode: - raise RuntimeError("You have to have determined mode.") - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - return _auroc_compute( - preds, - target, - self.mode, - self.num_classes, - self.pos_label, - self.average, - self.max_fpr, + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAUROC(max_fpr, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAUROC(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAUROC(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index cd19d76046e..be736fa35fe 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -23,8 +23,6 @@ MultilabelPrecisionRecallCurve, ) from torchmetrics.functional.classification.average_precision import ( - _average_precision_compute, - _average_precision_update, _binary_average_precision_compute, _multiclass_average_precision_arg_validation, _multiclass_average_precision_compute, @@ -32,7 +30,6 @@ _multilabel_average_precision_compute, ) from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -318,162 +315,58 @@ def compute(self) -> Tensor: ) -class AveragePrecision(Metric): - r"""Average Precision. +class AveragePrecision: + r"""Computes the average precision (AP) score. The AP score summarizes a precision-recall curve as an weighted + mean of precisions at each threshold, with the difference in recall from the previous threshold as weight: - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes the average precision score, which summarises the precision recall curve into one number. Works for - both binary and multiclass problems. In the case of multiclass, the values will be calculated based on a one- - vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` with integer labels + .. 
math:: + AP = \sum{n} (R_n - R_{n-1}) P_n - Args: - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: - defines the reduction that is applied in the case of multiclass and multilabel input. - Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. Cannot be - used with multiclass input. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support. - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryAveragePrecision`, :mod:`MulticlassAveragePrecision` and :mod:`MultilabelAveragePrecision` + for the specific details of each argument influence and examples. - Example (binary case): - >>> from torchmetrics import AveragePrecision + Legacy Example: >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision = AveragePrecision(pos_label=1) + >>> average_precision = AveragePrecision(task="binary") >>> average_precision(pred, target) tensor(1.) - Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision = AveragePrecision(num_classes=5, average=None) + >>> average_precision = AveragePrecision(task="multiclass", num_classes=5, average=None) >>> average_precision(pred, target) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] - def __new__( cls, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[Literal["macro", "weighted", "none"]] = "macro", - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, - num_labels: Optional[int] = None, - ignore_index: Optional[int] = None, - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - if task is not None: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryAveragePrecision(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassAveragePrecision(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelAveragePrecision(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, num_labels: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - allowed_average = ("micro", "macro", "weighted", "none", None) - if average not in allowed_average: - raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") - self.average = average - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `AveragePrecision` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." - ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _average_precision_update( - preds, target, self.num_classes, self.pos_label, self.average + ) -> Metric: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAveragePrecision(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassAveragePrecision(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelAveragePrecision(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[Tensor, List[Tensor]]: - """Compute the average precision score. - - Returns: - tensor with average precision. If multiclass return list of such tensors, one for each class - """ - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - if not self.num_classes: - raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") - return _average_precision_compute(preds, target, self.num_classes, self.pos_label, self.average) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 297dacf8325..f863addbe8b 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -23,7 +23,6 @@ _binary_calibration_error_update, _binary_confusion_matrix_format, _ce_compute, - _ce_update, _multiclass_calibration_error_arg_validation, _multiclass_calibration_error_tensor_validation, _multiclass_calibration_error_update, @@ -31,7 +30,6 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.prints import rank_zero_warn class BinaryCalibrationError(Metric): @@ -222,127 +220,48 @@ def compute(self) -> Tensor: return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) -class CalibrationError(Metric): - r"""Calibration Error. +class CalibrationError: + r"""`Computes the Top-label Calibration Error`_. The expected calibration error can be used to quantify how well + a given model is calibrated e.g. how well the predicted output probabilities of the model matches the actual + probabilities of the ground truth distribution. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. 
- - `Computes the Top-label Calibration Error`_ Three different norms are implemented, each corresponding to variations on the calibration error metric. - L1 norm (Expected Calibration Error) - .. math:: - \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\| - - Infinity norm (Maximum Calibration Error) + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} .. math:: - \text{MCE} = \max_{i} (p_i - c_i) - - L2 norm (Root Mean Square Calibration Error) + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} .. math:: - \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2} - - Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, - :math:`c_i` is the average confidence of predictions in bin :math:`i`, and - :math:`b_i` is the fraction of data points in bin :math:`i`. - - .. note:: - L2-norm debiasing is not yet supported. + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} - Args: - n_bins: Number of bins to use when computing probabilities and accuracies. - norm: Norm used to compare empirical and expected probability bins. - Defaults to "l1", or Expected Calibration Error. - debias: Applies debiasing term, only implemented for l2 norm. Defaults to True. + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :mod:`BinaryCalibrationError` and :mod:`MulticlassCalibrationError` for the specific details of + each argument influence and examples. """ - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False - DISTANCES = {"l1", "l2", "max"} - confidences: List[Tensor] - accuracies: List[Tensor] def __new__( cls, + task: Literal["binary", "multiclass"] = None, n_bins: int = 15, - norm: str = "l1", - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + norm: Literal["l1", "l2", "max"] = "l1", num_classes: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryCalibrationError(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassCalibrationError(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - n_bins: int = 15, - norm: str = "l1", - **kwargs: Any, - ): - - super().__init__(**kwargs) - - if norm not in self.DISTANCES: - raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") - - if not isinstance(n_bins, int) or n_bins <= 0: - raise ValueError(f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}") - self.n_bins = n_bins - self.bin_boundaries = torch.linspace(0, 1, n_bins + 1) - self.norm = norm - - self.add_state("confidences", [], dist_reduce_fx="cat") - self.add_state("accuracies", [], dist_reduce_fx="cat") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Computes top-level confidences and accuracies for the input probabilities and appends them to internal - state. - - Args: - preds (Tensor): Model output probabilities. - target (Tensor): Ground-truth target class labels. - """ - confidences, accuracies = _ce_update(preds, target) - - self.confidences.append(confidences) - self.accuracies.append(accuracies) - - def compute(self) -> Tensor: - """Computes calibration error across all confidences and accuracies. - - Returns: - Tensor: Calibration error across previously collected examples. - """ - confidences = dim_zero_cat(self.confidences) - accuracies = dim_zero_cat(self.accuracies) - return _ce_compute(confidences, accuracies, self.bin_boundaries.to(self.device), norm=self.norm) + kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCalibrationError(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCalibrationError(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 231617e4d1c..70679526ce1 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -20,13 +20,10 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( _binary_cohen_kappa_arg_validation, - _cohen_kappa_compute, _cohen_kappa_reduce, - _cohen_kappa_update, _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryCohenKappa(BinaryConfusionMatrix): @@ -180,17 +177,8 @@ def compute(self) -> Tensor: return _cohen_kappa_reduce(self.confmat, self.weights) -class CohenKappa(Metric): - r"""Cohen Kappa. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. 
- - Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as +class CohenKappa: + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as. .. math:: \kappa = (p_o - p_e) / (1 - p_e) @@ -200,105 +188,35 @@ class CohenKappa(Metric): :math:`p_e` is estimated using a per-annotator empirical prior over the class labels. - Works with binary, multiclass, and multilabel data. Accepts probabilities from a model output or - integer class values in prediction. Works with multi-dimensional preds and target. - - Forward accepts - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities or logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - weights: Weighting type to calculate the score. Choose from: - - - ``None`` or ``'none'``: no weighting - - ``'linear'``: linear weighting - - ``'quadratic'``: quadratic weighting - - threshold: - Threshold for transforming probability or logit predictions to binary ``(0,1)`` predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :mod:`BinaryCohenKappa` and :mod:`MulticlassCohenKappa` for the specific details of + each argument influence and examples. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example: - >>> from torchmetrics import CohenKappa + Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> cohenkappa = CohenKappa(num_classes=2) + >>> cohenkappa = CohenKappa(task="multiclass", num_classes=2) >>> cohenkappa(preds, target) tensor(0.5000) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False - confmat: Tensor def __new__( cls, - num_classes: Optional[int] = None, - weights: Optional[str] = None, + task: Literal["binary", "multiclass"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryCohenKappa(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassCohenKappa(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: int, - weights: Optional[str] = None, - threshold: float = 0.5, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.num_classes = num_classes - self.weights = weights - self.threshold = threshold - - allowed_weights = ("linear", "quadratic", "none", None) - if self.weights not in allowed_weights: - raise ValueError(f"Argument weights needs to one of the following: {allowed_weights}") - - self.add_state("confmat", default=torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - confmat = _cohen_kappa_update(preds, target, self.num_classes, self.threshold) - self.confmat += confmat - - def compute(self) -> Tensor: - """Computes cohen kappa score.""" - return _cohen_kappa_compute(self.confmat, self.weights) + kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCohenKappa(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassCohenKappa(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 9a2d6c8486d..e2fd33b9105 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -23,8 +23,6 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, - _confusion_matrix_compute, - _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_compute, _multiclass_confusion_matrix_format, @@ -37,7 +35,6 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryConfusionMatrix(Metric): @@ -315,159 +312,59 @@ def compute(self) -> Tensor: return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) -class ConfusionMatrix(Metric): - r"""Confusion Matrix. +class ConfusionMatrix: + r"""Computes the `confusion matrix`_. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. 
See the documentation of + :mod:`BinaryConfusionMatrix`, :mod:`MulticlassConfusionMatrix` and :func:`MultilabelConfusionMatrix` for + the specific details of each argument influence and examples. - Computes the `confusion matrix`_. - - Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output - or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that - additional dimensions will be flattened. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities or logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - If working with multilabel data, setting the ``is_multilabel`` argument to ``True`` will make sure that a - `confusion matrix gets calculated per label`_. - - Args: - num_classes: Number of classes in the dataset. - normalize: Normalization mode for confusion matrix. Choose from: - - - ``None`` or ``'none'``: no normalization (default) - - ``'true'``: normalization over the targets (most commonly used) - - ``'pred'``: normalization over the predictions - - ``'all'``: normalization over the whole matrix - - threshold: - Threshold for transforming probability or logit predictions to binary ``(0,1)`` predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. - - multilabel: determines if data is multilabel or not. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
- - Example (binary data): - >>> from torchmetrics import ConfusionMatrix + Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = ConfusionMatrix(num_classes=2) + >>> confmat = ConfusionMatrix(task="binary", num_classes=2) >>> confmat(preds, target) tensor([[2, 0], [1, 1]]) - Example (multiclass data): >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) - >>> confmat = ConfusionMatrix(num_classes=3) + >>> confmat = ConfusionMatrix(task="multiclass", num_classes=3) >>> confmat(preds, target) tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]]) - Example (multilabel data): >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) - >>> confmat = ConfusionMatrix(num_classes=3, multilabel=True) + >>> confmat = ConfusionMatrix(task="multilabel", num_labels=3) >>> confmat(preds, target) tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - confmat: Tensor def __new__( cls, - num_classes: Optional[int] = None, - normalize: Optional[str] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - multilabel: bool = False, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryConfusionMatrix(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassConfusionMatrix(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: int, - normalize: Optional[str] = None, - threshold: float = 0.5, - multilabel: bool = False, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.num_classes = num_classes - self.normalize = normalize - self.threshold = threshold - self.multilabel = multilabel - - allowed_normalize = ("true", "pred", "all", "none", None) - if self.normalize not in allowed_normalize: - raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}") - - if multilabel: - default = torch.zeros(num_classes, 2, 2, dtype=torch.long) - else: - default = torch.zeros(num_classes, num_classes, dtype=torch.long) - self.add_state("confmat", default=default, dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold, self.multilabel) - self.confmat += confmat - - def compute(self) -> Tensor: - """Computes confusion matrix. - - Returns: - If ``multilabel=False`` this will be a ``[n_classes, n_classes]`` tensor and if ``multilabel=True`` - this will be a ``[n_classes, 2, 2]`` tensor. - """ - return _confusion_matrix_compute(self.confmat, self.normalize) + kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryConfusionMatrix(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassConfusionMatrix(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 9e025dac564..926770f6b54 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -11,17 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Callable, Optional, Tuple, no_type_check +import torch from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute -from torchmetrics.utilities.enums import AverageMethod +from torchmetrics.functional.classification.stat_scores import _stat_scores_update +from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -class Dice(StatScores): +class Dice(Metric): r"""Computes `Dice`_: .. 
math:: \text{Dice} = \frac{\text{2 * TP}}{\text{2 * TP} + \text{FP} + \text{FN}} @@ -117,6 +119,7 @@ class Dice(StatScores): higher_is_better: bool = True full_state_update: bool = False + @no_type_check def __init__( self, zero_division: int = 0, @@ -129,6 +132,7 @@ def __init__( multiclass: Optional[bool] = None, **kwargs: Any, ) -> None: + super().__init__(**kwargs) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -139,18 +143,86 @@ def __init__( if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, - ) + self.reduce = average + self.mdmc_reduce = mdmc_average + self.num_classes = num_classes + self.threshold = threshold + self.multiclass = multiclass + self.ignore_index = ignore_index + self.top_k = top_k + + if average not in ["micro", "macro", "samples"]: + raise ValueError(f"The `reduce` {average} is not valid.") + + if mdmc_average not in [None, "samplewise", "global"]: + raise ValueError(f"The `mdmc_reduce` {mdmc_average} is not valid.") + + if average == "macro" and (not num_classes or num_classes < 1): + raise ValueError("When you set `average` as 'macro', you have to provide the number of classes.") + + if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): + raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") + + default: Callable = lambda: [] + reduce_fn: Optional[str] = "cat" + if mdmc_average != "samplewise" and average != "samples": + if average == "micro": + zeros_shape = [] + elif average == "macro": + zeros_shape = [num_classes] + else: + raise ValueError(f'Wrong reduce="{average}"') + default = lambda: torch.zeros(zeros_shape, dtype=torch.long) + reduce_fn = "sum" + + for s in ("tp", "fp", "tn", "fn"): + self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) self.average = average self.zero_division = zero_division + @no_type_check + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets. 
+ + Args: + preds: Predictions from model (probabilities, logits or labels) + target: Ground truth values + """ + tp, fp, tn, fn = _stat_scores_update( + preds, + target, + reduce=self.reduce, + mdmc_reduce=self.mdmc_reduce, + threshold=self.threshold, + num_classes=self.num_classes, + top_k=self.top_k, + multiclass=self.multiclass, + ignore_index=self.ignore_index, + ) + + # Update states + if self.reduce != AverageMethod.SAMPLES and self.mdmc_reduce != MDMCAverageMethod.SAMPLEWISE: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + else: + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + + @no_type_check + def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Performs concatenation on the stat scores if necessary, before passing them to a compute function.""" + tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp + fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp + tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn + fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + return tp, fp, tn, fn + + @no_type_check def compute(self) -> Tensor: """Computes the dice score based on inputs passed in to ``update`` previously. diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 19336dcc3f7..49fd39005c5 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -17,22 +17,14 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import ( - BinaryStatScores, - MulticlassStatScores, - MultilabelStatScores, - StatScores, -) +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.f_beta import ( _binary_fbeta_score_arg_validation, - _fbeta_compute, _fbeta_reduce, _multiclass_fbeta_score_arg_validation, _multilabel_fbeta_score_arg_validation, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import AverageMethod -from torchmetrics.utilities.prints import rank_zero_warn class BinaryFBetaScore(BinaryStatScores): @@ -708,354 +700,101 @@ def __init__( ) -class FBetaScore(StatScores): - r"""F-Beta Score. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `F-score`_, specifically: +class FBetaScore: + r"""Computes `F-score`_ metric: .. math:: - F_\beta = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} + F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}} - Where :math:`\beta` is some positive real factor. Works with binary, multiclass, and multilabel data. - Accepts logit scores or probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target.
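The ``Dice`` rework above folds the former ``StatScores`` machinery into the class itself (the ``tp``/``fp``/``tn``/``fn`` states are summed tensors, or concatenated lists when ``average='samples'`` or ``mdmc_average='samplewise'``), but the public interface is unchanged. A small usage sketch with illustrative values:

    >>> import torch
    >>> from torchmetrics import Dice
    >>> preds = torch.tensor([2, 0, 2, 1])
    >>> target = torch.tensor([1, 1, 2, 0])
    >>> dice = Dice(average='micro')
    >>> dice(preds, target)
    tensor(0.2500)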
- - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label logits and probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - beta: Beta coefficient in the F measure. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The default value (``None``) will be interpreted - as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
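With the argument documentation above removed, the wrapper's only job is routing: the new ``__new__`` below forwards to the specialized class, so constructing the wrapper or the specialized metric directly is interchangeable. A minimal sketch (the explicit ``average="micro"`` matches the wrapper's default; the submodule import path follows this file's location):

    >>> import torch
    >>> from torchmetrics import FBetaScore
    >>> from torchmetrics.classification.f_beta import MulticlassFBetaScore
    >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
    >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
    >>> wrapper = FBetaScore(task="multiclass", num_classes=3, beta=0.5)
    >>> direct = MulticlassFBetaScore(beta=0.5, num_classes=3, average="micro")
    >>> torch.equal(wrapper(preds, target), direct(preds, target))
    True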
+ This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :func:`binary_fbeta_score`, :func:`multiclass_fbeta_score` and :func:`multilabel_fbeta_score` for the specific + details of each argument influence and examples. - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"``, ``None``. - - Example: + Legacy Example: >>> import torch - >>> from torchmetrics import FBetaScore >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f_beta = FBetaScore(num_classes=3, beta=0.5) + >>> f_beta = FBetaScore(task="multiclass", num_classes=3, beta=0.5) >>> f_beta(preds, target) tensor(0.3333) """ - full_state_update: bool = False def __new__( cls, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], beta: float = 1.0, threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryFBetaScore(beta, threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11.
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - beta: float = 1.0, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - **kwargs: Any, - ) -> None: - self.beta = beta - allowed_average = list(AverageMethod) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) - if "reduce" not in kwargs: - kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average - if "mdmc_reduce" not in kwargs: - kwargs["mdmc_reduce"] = mdmc_average - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryFBetaScore(beta, threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - self.average = average - - def compute(self) -> Tensor: - """Computes f-beta over state.""" - tp, fp, tn, fn = self._get_final_stats() - return _fbeta_compute(tp, fp, tn, fn, self.beta, self.ignore_index, self.average, self.mdmc_reduce) - - -class F1Score(FBetaScore): - r"""F1 Score. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes F1 metric. - - F1 metrics correspond to a harmonic mean of the precision and recall scores. - Works with binary, multiclass, and multilabel data. Accepts logits or probabilities from a model - output or integer class values in prediction. Works with multi-dimensional preds and target. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. +class F1Score: + r"""Computes F-1 score: - Args: - num_classes: - Number of classes. 
Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold for transforming probability or logit predictions to binary ``(0,1)`` predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + .. math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryF1Score`, :mod:`MulticlassF1Score` and :mod:`MultilabelF1Score` for the specific + details of each argument influence and examples. 
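The same routing applies to ``F1Score``: with ``task="binary"`` the call below is handled entirely by ``BinaryF1Score``. A small sketch with illustrative values:

    >>> import torch
    >>> from torchmetrics import F1Score
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> preds = torch.tensor([0, 1, 0, 0])
    >>> f1 = F1Score(task="binary")
    >>> f1(preds, target)
    tensor(0.6667)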
- Example: + Legacy Example: >>> import torch - >>> from torchmetrics import F1Score >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f1 = F1Score(num_classes=3) + >>> f1 = F1Score(task="multiclass", num_classes=3) >>> f1(preds, target) tensor(0.3333) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False - def __new__( cls, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryF1Score(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassF1Score(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelF1Score(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - **kwargs: Any, - ) -> None: - super().__init__( - num_classes=num_classes, - beta=1.0, - threshold=threshold, - average=average, - mdmc_average=mdmc_average, - ignore_index=ignore_index, - top_k=top_k, - multiclass=multiclass, - **kwargs, + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryF1Score(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassF1Score(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelF1Score(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index ab04b26aa0a..cbb7a5ce987 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -14,17 +14,12 @@ from typing import Any, Optional import torch -from torch import Tensor, tensor +from torch import Tensor from typing_extensions import Literal from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores -from torchmetrics.functional.classification.hamming import ( - _hamming_distance_compute, - _hamming_distance_reduce, - _hamming_distance_update, -) +from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryHammingDistance(BinaryStatScores): @@ -317,119 +312,54 @@ def compute(self) -> Tensor: ) -class HammingDistance(Metric): - r"""Hamming distance. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: +class HammingDistance: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss): .. 
math:: - \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that tensor. - This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. - - Args: - threshold: - Threshold for transforming probability or logit predictions to binary ``(0,1)`` predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``threshold`` is not between ``0`` and ``1``. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryHammingDistance`, :mod:`MulticlassHammingDistance` and :mod:`MultilabelHammingDistance` for the + specific details of each argument influence and examples. - Example: - >>> from torchmetrics import HammingDistance + Legacy Example: >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_distance = HammingDistance() + >>> hamming_distance = HammingDistance(task="multilabel", num_labels=2) >>> hamming_distance(preds, target) tensor(0.2500) """ - is_differentiable: bool = False - higher_is_better: bool = False - full_state_update: bool = False - correct: Tensor - total: Tensor def __new__( cls, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", - top_k: Optional[int] = None, + top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryHammingDistance(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - threshold: float = 0.5, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.add_state("correct", default=tensor(0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") - - self.threshold = threshold - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth labels - """ - correct, total = _hamming_distance_update(preds, target, self.threshold) - - self.correct += correct - self.total += total - - def compute(self) -> Tensor: - """Computes hamming distance based on inputs passed in to ``update`` previously.""" - return _hamming_distance_compute(self.correct, self.total) + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryHammingDistance(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index d80f6085c76..1c8b05986ab 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -11,28 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Optional import torch from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.hinge import ( - MulticlassMode, _binary_confusion_matrix_format, _binary_hinge_loss_arg_validation, _binary_hinge_loss_tensor_validation, _binary_hinge_loss_update, - _hinge_compute, _hinge_loss_compute, - _hinge_update, _multiclass_confusion_matrix_format, _multiclass_hinge_loss_arg_validation, _multiclass_hinge_loss_tensor_validation, _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryHingeLoss(Metric): @@ -205,146 +201,51 @@ def compute(self) -> Tensor: return _hinge_loss_compute(self.measures, self.total) -class HingeLoss(Metric): - r"""Hinge Loss. +class HingeLoss: + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :mod:`BinaryHingeLoss` and :mod:`MulticlassHingeLoss` for the specific details of + each argument influence and examples. - Computes the mean `Hinge loss`_, typically used for Support Vector Machines (SVMs). - - In the binary case it is defined as: - - .. math:: - \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) - - Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. - - In the multi-class case, when ``multiclass_mode=None`` (default), ``multiclass_mode=MulticlassMode.CRAMMER_SINGER`` - or ``multiclass_mode="crammer-singer"``, this metric will compute the multi-class hinge loss defined by Crammer and - Singer as: - - .. math:: - \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) - - Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), - and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. - - In the multi-class case when ``multiclass_mode=MulticlassMode.ONE_VS_ALL`` or ``multiclass_mode='one-vs-all'``, this - metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits - that class against all remaining classes. - - This metric can optionally output the mean of the squared hinge loss by setting ``squared=True`` - - Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N). - - Args: - squared: - If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default). - multiclass_mode: - Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default), - ``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss. - ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - - Raises: - ValueError: - If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``, - ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``. 
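Note that in the rewritten ``__new__`` below, ``multiclass_mode`` defaults to ``"crammer-singer"`` instead of ``None``, which keeps the previous default behaviour for multiclass inputs. A minimal dispatch sketch (the submodule import path follows this file's location):

    >>> from torchmetrics import HingeLoss
    >>> from torchmetrics.classification.hinge import MulticlassHingeLoss
    >>> hinge = HingeLoss(task="multiclass", num_classes=3)  # crammer-singer mode by default
    >>> isinstance(hinge, MulticlassHingeLoss)
    True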
- - Example (binary case): + Legacy Example: >>> import torch - >>> from torchmetrics import HingeLoss >>> target = torch.tensor([0, 1, 1]) - >>> preds = torch.tensor([-2.2, 2.4, 0.1]) - >>> hinge = HingeLoss() + >>> preds = torch.tensor([0.5, 0.7, 0.1]) + >>> hinge = HingeLoss(task="binary") >>> hinge(preds, target) - tensor(0.3000) + tensor(0.9000) - Example (default / multiclass case): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge = HingeLoss() + >>> hinge = HingeLoss(task="multiclass", num_classes=3) >>> hinge(preds, target) - tensor(2.9000) + tensor(1.5551) - Example (multiclass example, one vs all mode): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge = HingeLoss(multiclass_mode="one-vs-all") + >>> hinge = HingeLoss(task="multiclass", num_classes=3, multiclass_mode="one-vs-all") >>> hinge(preds, target) - tensor([2.2333, 1.5000, 1.2333]) + tensor([1.3743, 1.1945, 1.2359]) """ - is_differentiable: bool = True - higher_is_better: bool = False - full_state_update: bool = False - measure: Tensor - total: Tensor def __new__( cls, - squared: bool = False, - multiclass_mode: Literal["crammer-singer", "one-vs-all"] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass"], num_classes: Optional[int] = None, + squared: bool = False, + multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = "crammer-singer", ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryHingeLoss(squared, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert multiclass_mode is not None - return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.add_state("measure", default=torch.tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") - - if multiclass_mode not in (None, MulticlassMode.CRAMMER_SINGER, MulticlassMode.ONE_VS_ALL): - raise ValueError( - "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" - "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL," - f" got {multiclass_mode}." 
- ) - - self.squared = squared - self.multiclass_mode = multiclass_mode - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - measure, total = _hinge_update(preds, target, squared=self.squared, multiclass_mode=self.multiclass_mode) - - self.measure = measure + self.measure - self.total = total + self.total - - def compute(self) -> Tensor: - return _hinge_compute(self.measure, self.total) + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryHingeLoss(squared, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index a4e7e10ed35..c637a8a94b6 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -18,15 +18,12 @@ from typing_extensions import Literal from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix -from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.jaccard import ( - _jaccard_from_confmat, _jaccard_index_reduce, _multiclass_jaccard_index_arg_validation, _multilabel_jaccard_index_arg_validation, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -255,154 +252,48 @@ def compute(self) -> Tensor: return _jaccard_index_reduce(self.confmat, average=self.average) -class JaccardIndex(ConfusionMatrix): - r"""Jaccard Index. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes Intersection over union, or `Jaccard index`_: +class JaccardIndex: + r"""Calculates the Jaccard index. The `Jaccard index`_ (also known as the intersection over + union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and + diversity of a sample set. It is defined as the size of the intersection divided by the size of the union of the + sample sets: .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values. - They may be subject to conversion from input data (see description below). Note that it is different from box IoU. - - Works with binary, multiclass and multi-label data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels.
This is the case for binary and multi-label probabilities. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryJaccardIndex`, :mod:`MulticlassJaccardIndex` and :mod:`MultilabelJaccardIndex` for + the specific details of each argument influence and examples. - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. Note that if a given class doesn't occur in the - `preds` or `target`, the value for the class will be ``nan``. - - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1]. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - ``preds`` AND no instances of the class index were present in ``target``. For example, if we have 3 classes, - [0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be assigned the `absent_score`. - threshold: Threshold value for binary or multi-label probabilities. - multilabel: determines if data is multilabel or not. - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
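For the multilabel route the wrapper passes ``num_labels``, ``threshold`` and ``average`` straight through to ``MultilabelJaccardIndex`` (macro averaging is the wrapper's default). A small sketch with illustrative values:

    >>> import torch
    >>> from torchmetrics import JaccardIndex
    >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
    >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
    >>> jaccard = JaccardIndex(task="multilabel", num_labels=3)
    >>> jaccard(preds, target)
    tensor(0.5000)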
- - Example: - >>> from torchmetrics import JaccardIndex + Legacy Example: >>> target = torch.randint(0, 2, (10, 25, 25)) >>> pred = torch.tensor(target) >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> jaccard = JaccardIndex(num_classes=2) + >>> jaccard = JaccardIndex(task="multiclass", num_classes=2) >>> jaccard(pred, target) tensor(0.9660) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False def __new__( cls, - num_classes: int, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - ignore_index: Optional[int] = None, - absent_score: float = 0.0, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - multilabel: bool = False, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - if task is not None: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryJaccardIndex(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassJaccardIndex(num_classes, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: int, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, - absent_score: float = 0.0, - threshold: float = 0.5, - multilabel: bool = False, + validate_args: bool = True, **kwargs: Any, - ) -> None: - kwargs["normalize"] = kwargs.get("normalize") - - super().__init__( - num_classes=num_classes, - threshold=threshold, - multilabel=multilabel, - **kwargs, + ) -> Metric: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryJaccardIndex(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassJaccardIndex(num_classes, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - self.average = average - self.ignore_index = ignore_index - self.absent_score = absent_score - - def compute(self) -> Tensor: - """Computes intersection over union (IoU)""" - - if self.multilabel: - return torch.stack( - [ - _jaccard_from_confmat( - confmat, - 2, - self.average, - self.ignore_index, - self.absent_score, - )[1] - for confmat in self.confmat - ] - ) - else: - return _jaccard_from_confmat( - self.confmat, - self.num_classes, - self.average, - self.ignore_index, - self.absent_score, - ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index d6a47d481d7..aca91fd5a18 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -18,13 +18,8 @@ from typing_extensions import Literal from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix -from torchmetrics.functional.classification.matthews_corrcoef import ( - _matthews_corrcoef_compute, - _matthews_corrcoef_reduce, - _matthews_corrcoef_update, -) +from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce from torchmetrics.metric import Metric -from torchmetrics.utilities.prints import rank_zero_warn class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): @@ -220,116 +215,42 @@ def compute(self) -> Tensor: return _matthews_corrcoef_reduce(self.confmat) -class MatthewsCorrCoef(Metric): - r"""Matthews correlation coefficient. +class MatthewsCorrCoef: + r"""Calculates `Matthews correlation coefficient`_ . This metric measures the general correlation or quality of + a classification. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. 
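As with the other wrappers in this patch, ``MatthewsCorrCoef(task="binary", ...)`` simply returns a ``BinaryMatthewsCorrCoef``, so the two spellings below are interchangeable (values taken from the Legacy Example further down; the submodule import path follows this file's location):

    >>> import torch
    >>> from torchmetrics import MatthewsCorrCoef
    >>> from torchmetrics.classification.matthews_corrcoef import BinaryMatthewsCorrCoef
    >>> target = torch.tensor([1, 1, 0, 0])
    >>> preds = torch.tensor([0, 1, 0, 0])
    >>> MatthewsCorrCoef(task="binary")(preds, target)
    tensor(0.5774)
    >>> BinaryMatthewsCorrCoef()(preds, target)
    tensor(0.5774)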
+ This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryMatthewsCorrCoef`, :mod:`MulticlassMatthewsCorrCoef` and :mod:`MultilabelMatthewsCorrCoef` for + the specific details of each argument influence and examples. - Calculates `Matthews correlation coefficient`_ that measures the general correlation - or quality of a classification. - - In the binary case it is defined as: - - .. math:: - MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}} - - where TP, TN, FP and FN are respectively the true postitives, true negatives, - false positives and false negatives. Also works in the case of multi-label or - multi-class input. - - Note: - This metric produces a multi-dimensional output, so it can not be directly logged. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - threshold: Threshold value for binary or multi-label probabilites. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example: - >>> from torchmetrics import MatthewsCorrCoef + Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> matthews_corrcoef = MatthewsCorrCoef(num_classes=2) + >>> matthews_corrcoef = MatthewsCorrCoef(task='binary') >>> matthews_corrcoef(preds, target) tensor(0.5774) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False - confmat: Tensor def __new__( cls, - num_classes: int, + task: Literal["binary", "multiclass", "multilabel"] = None, threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryMatthewsCorrCoef(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassMatthewsCorrCoef(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: int, - threshold: float = 0.5, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.num_classes = num_classes - self.threshold = threshold - - self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. - - Args: - preds: Predictions from model - target: Ground truth values - """ - confmat = _matthews_corrcoef_update(preds, target, self.num_classes, self.threshold) - self.confmat += confmat - - def compute(self) -> Tensor: - """Computes matthews correlation coefficient.""" - return _matthews_corrcoef_compute(self.confmat) + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryMatthewsCorrCoef(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassMatthewsCorrCoef(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 21d20fd1401..146822bfd38 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -17,20 +17,9 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import ( - BinaryStatScores, - MulticlassStatScores, - MultilabelStatScores, - StatScores, -) -from torchmetrics.functional.classification.precision_recall import ( - _precision_compute, - _precision_recall_reduce, - _recall_compute, -) +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import AverageMethod -from torchmetrics.utilities.prints import rank_zero_warn class BinaryPrecision(BinaryStatScores): @@ -603,376 +592,109 @@ def compute(self) -> Tensor: ) -class Precision(StatScores): - r"""Precision. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Precision`_: +class Precision: + r"""Computes `Precision`_: .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. 
With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. - the inputs are treated as if they were ``(N_X, C)``. - From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + false positives respectively. - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. 
See the documentation of + :mod:`BinaryPrecision`, :mod:`MulticlassPrecision` and :mod:`MultilabelPrecision` for the specific details of + each argument's influence and examples. - Example: + Legacy Example: >>> import torch - >>> from torchmetrics import Precision >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision = Precision(average='macro', num_classes=3) + >>> precision = Precision(task="multiclass", average='macro', num_classes=3) >>> precision(preds, target) tensor(0.1667) - >>> precision = Precision(average='micro') + >>> precision = Precision(task="multiclass", average='micro', num_classes=3) >>> precision(preds, target) tensor(0.2500) """ - is_differentiable = False - higher_is_better = True - full_state_update: bool = False def __new__( cls, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryPrecision(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassPrecision(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelPrecision(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - **kwargs: Any, - ) -> None: - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) - if "reduce" not in kwargs: - kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average - if "mdmc_reduce" not in kwargs: - kwargs["mdmc_reduce"] = mdmc_average - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryPrecision(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassPrecision(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecision(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - self.average = average - - def compute(self) -> Tensor: - """Computes the precision score based on inputs passed in to ``update`` previously. - - Return: - The shape of the returned tensor depends on the ``average`` parameter: - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - """ - tp, fp, _, fn = self._get_final_stats() - return _precision_compute(tp, fp, fn, self.average, self.mdmc_reduce) - -class Recall(StatScores): - r"""Recall. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Recall`_: +class Recall: + r"""Computes `Recall`_: .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. 
- - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold for transforming probability or logit predictions to binary ``(0,1)`` predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class. The default value (``None``) will be interpreted - as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + false negatives respectively. - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :mod:`BinaryRecall`, :mod:`MulticlassRecall` and :mod:`MultilabelRecall` for the specific details of + each argument's influence and examples. 
- Example: + Legacy Example: >>> import torch - >>> from torchmetrics import Recall >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall = Recall(average='macro', num_classes=3) + >>> recall = Recall(task="multiclass", average='macro', num_classes=3) >>> recall(preds, target) tensor(0.3333) - >>> recall = Recall(average='micro') + >>> recall = Recall(task="multiclass", average='micro', num_classes=3) >>> recall(preds, target) tensor(0.2500) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False def __new__( cls, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryRecall(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassRecall(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelRecall(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - **kwargs: Any, - ) -> None: - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) - if "reduce" not in kwargs: - kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average - if "mdmc_reduce" not in kwargs: - kwargs["mdmc_reduce"] = mdmc_average - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryRecall(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassRecall(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelRecall(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - self.average = average - - def compute(self) -> Tensor: - """Computes the recall score based on inputs passed in to ``update`` previously. - - Return: - The shape of the returned tensor depends on the ``average`` parameter: - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - """ - tp, fp, _, fn = self._get_final_stats() - return _recall_compute(tp, fp, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 8615d4cf51a..ad367a41249 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -34,11 +34,8 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, - _precision_recall_curve_compute, - _precision_recall_curve_update, ) from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -420,40 +417,19 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li return _multilabel_precision_recall_curve_compute(state, self.num_labels, self.thresholds, self.ignore_index) -class PrecisionRecallCurve(Metric): - r"""Precision Recall Curve. +class PrecisionRecallCurve: + r"""Computes the precision-recall curve. 
The curve consists of multiple pairs of precision and recall values + evaluated at different thresholds, such that the tradeoff between the two values can be seen. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :mod:`BinaryPrecisionRecallCurve`, :mod:`MulticlassPrecisionRecallCurve` and + :mod:`MultilabelPrecisionRecallCurve` for the specific details of each argument's influence and examples. - Computes precision-recall pairs for different thresholds. Works for both binary and multiclass problems. In - the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor - with probabilities, where C is the number of classes. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - Args: - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the range - ``[0, num_classes-1]`` - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example (binary case): - >>> from torchmetrics import PrecisionRecallCurve + Legacy Example: >>> pred = torch.tensor([0, 0.1, 0.8, 0.4]) >>> target = torch.tensor([0, 1, 1, 0]) - >>> pr_curve = PrecisionRecallCurve(pos_label=1) + >>> pr_curve = PrecisionRecallCurve(task="binary") >>> precision, recall, thresholds = pr_curve(pred, target) >>> precision tensor([0.6667, 0.5000, 1.0000, 1.0000]) @@ -462,13 +438,12 @@ class PrecisionRecallCurve(Metric): >>> thresholds tensor([0.1000, 0.4000, 0.8000]) - Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> pr_curve = PrecisionRecallCurve(num_classes=5) + >>> pr_curve = PrecisionRecallCurve(task="multiclass", num_classes=5) >>> precision, recall, thresholds = pr_curve(pred, target) >>> precision [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), @@ -479,100 +454,25 @@ class PrecisionRecallCurve(Metric): [tensor(0.7500), tensor(0.7500), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor(0.0500)] """ - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] - def __new__( cls, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryPrecisionRecallCurve(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassPrecisionRecallCurve(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelPrecisionRecallCurve(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `PrecisionRecallCurve` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." - ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _precision_recall_curve_update( - preds, target, self.num_classes, self.pos_label + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryPrecisionRecallCurve(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassPrecisionRecallCurve(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelPrecisionRecallCurve(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Compute the precision-recall curve. - - Returns: - 3-element tuple containing - - precision: - tensor where element ``i`` is the precision of predictions with - ``score >= thresholds[i]`` and the last element is 1. - If multiclass, this is a list of such tensors, one for each class. - recall: - tensor where element ``i`` is the recall of predictions with - ``score >= thresholds[i]`` and the last element is 0. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - Thresholds used for computing precision/recall scores - """ - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - if not self.num_classes: - raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") - return _precision_recall_curve_compute(preds, target, self.num_classes, self.pos_label) diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 9bd3665fefd..4c10eef5b18 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, @@ -29,6 +30,7 @@ _multilabel_recall_at_fixed_precision_arg_compute, _multilabel_recall_at_fixed_precision_arg_validation, ) +from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat @@ -294,3 +296,42 @@ def compute(self) -> Tuple[Tensor, Tensor]: # type: ignore[override] return _multilabel_recall_at_fixed_precision_arg_compute( state, self.num_labels, self.thresholds, self.ignore_index, self.min_precision ) + + +class RecallAtFixedPrecision: + r"""Computes the highest possible recall value given the minimum precision threshold provided. This is done by + first calculating the precision-recall curve for different thresholds and then finding the recall for a given + precision level. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :mod:`BinaryRecallAtFixedPrecision`, :mod:`MulticlassRecallAtFixedPrecision` and + :mod:`MultilabelRecallAtFixedPrecision` for the specific details of each argument's influence and examples. 
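+
+    Example (an illustrative sketch of the binary dispatch; the ``preds``/``target`` values are placeholder data
+    chosen only for demonstration, and the class is imported directly from the module it is defined in):
+
+        >>> import torch
+        >>> from torchmetrics.classification.recall_at_fixed_precision import RecallAtFixedPrecision
+        >>> preds = torch.tensor([0.25, 0.50, 0.75, 0.95])
+        >>> target = torch.tensor([0, 0, 1, 1])
+        >>> metric = RecallAtFixedPrecision(task="binary", min_precision=0.5)
+        >>> recall, threshold = metric(preds, target)  # best achievable recall with precision >= 0.5, and its threshold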
+ """ + + def __new__( + cls, + task: Literal["binary", "multiclass", "multilabel"], + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + if task == "binary": + return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassRecallAtFixedPrecision( + num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelRecallAtFixedPrecision( + num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index d53c1057606..a26b80ee873 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -26,11 +26,8 @@ _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, - _roc_compute, - _roc_update, ) from torchmetrics.metric import Metric -from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat @@ -305,78 +302,51 @@ def compute(self) -> Tuple[Tensor, Tensor, Tensor]: return _multilabel_roc_compute(state, self.num_labels, self.thresholds, self.ignore_index) -class ROC(Metric): - r"""Receiver Operating Characteristic. +class ROC: + r"""Computes the Receiver Operating Characteristic (ROC). The curve consists of multiple pairs of true positive + rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff + between the two values can be seen. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :mod:`BinaryROC`, :mod:`MulticlassROC` and :mod:`MultilabelROC` for the specific details of each argument's + influence and examples. - Computes the Receiver Operating Characteristic (ROC). Works for both binary, multiclass and multilabel - problems. In the case of multiclass, the values will be calculated based on a one-vs-the-rest approach. - - Forward accepts - - - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass/multilabel) tensor - with probabilities, where C is the number of classes/labels. - - - ``target`` (long tensor): ``(N, ...)`` or ``(N, C, ...)`` with integer labels - - .. 
note:: - If either the positive class or negative class is completly missing in the target tensor, - the roc values are not well-defined in this case and a tensor of zeros will be returned (either fpr - or tpr depending on what class is missing) together with a warning. - - Args: - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the range - ``[0,num_classes-1]`` - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Example (binary case): - >>> from torchmetrics import ROC - >>> pred = torch.tensor([0, 1, 2, 3]) + Legacy Example: + >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> roc = ROC(pos_label=1) + >>> roc = ROC(task="binary") >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) >>> thresholds - tensor([4, 3, 2, 1, 0]) + tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000]) - Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05], ... [0.05, 0.05, 0.05, 0.75]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> roc = ROC(num_classes=4) + >>> roc = ROC(task="multiclass", num_classes=4) >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] >>> tpr [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + [tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500])] - Example (multilabel case): >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138], ... [0.3584, 0.7576, 0.1183], ... [0.2286, 0.3468, 0.1338], ... 
[0.8603, 0.0745, 0.1837]]) >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]]) - >>> roc = ROC(num_classes=3, pos_label=1) + >>> roc = ROC(task='multilabel', num_labels=3) >>> fpr, tpr, thresholds = roc(pred, target) >>> fpr [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), @@ -387,99 +357,30 @@ class ROC(Metric): tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])] >>> thresholds - [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]), - tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), - tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] + [tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]), + tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), + tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - preds: List[Tensor] - target: List[Tensor] - def __new__( cls, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) - if task == "binary": - return BinaryROC(**kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - return MulticlassROC(num_classes, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelROC(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - self.num_classes = num_classes - self.pos_label = pos_label - - self.add_state("preds", default=[], dist_reduce_fx="cat") - self.add_state("target", default=[], dist_reduce_fx="cat") - - rank_zero_warn( - "Metric `ROC` will save all targets and predictions in buffer." - " For large datasets this may lead to large memory footprint." + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryROC(**kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + return MulticlassROC(num_classes, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelROC(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model - target: Ground truth values - """ - preds, target, num_classes, pos_label = _roc_update(preds, target, self.num_classes, self.pos_label) - self.preds.append(preds) - self.target.append(target) - self.num_classes = num_classes - self.pos_label = pos_label - - def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Compute the receiver operating characteristic. - - Returns: - 3-element tuple containing - - fpr: tensor with false positive rates. - If multiclass, this is a list of such tensors, one for each class. - tpr: tensor with true positive rates. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - thresholds used for computing false- and true-positive rates - """ - preds = dim_zero_cat(self.preds) - target = dim_zero_cat(self.target) - if not self.num_classes: - raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}") - return _roc_compute(preds, target, self.num_classes, self.pos_label) diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 0e618c5fccb..2a572e153f1 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -17,16 +17,9 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.classification.stat_scores import ( - BinaryStatScores, - MulticlassStatScores, - MultilabelStatScores, - StatScores, -) -from torchmetrics.functional.classification.specificity import _specificity_compute, _specificity_reduce +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.specificity import _specificity_reduce from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import AverageMethod -from torchmetrics.utilities.prints import rank_zero_warn class BinarySpecificity(BinaryStatScores): @@ -294,189 +287,54 @@ def compute(self) -> Tensor: return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) -class Specificity(StatScores): - r"""Specificity. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Specificity`_: +class Specificity: + r"""Computes `Specificity`_. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Specificity@K. - - The reduction method (how the specificity scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. 
- threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tn + fp``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - top_k: - Number of the highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + false positives respectively. - Raises: - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :mod:`BinarySpecificity`, :mod:`MulticlassSpecificity` and :mod:`MultilabelSpecificity` for the specific + details of each argument's influence and examples. 
- Example: - >>> from torchmetrics import Specificity + Legacy Example: >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> specificity = Specificity(average='macro', num_classes=3) + >>> specificity = Specificity(task="multiclass", average='macro', num_classes=3) >>> specificity(preds, target) tensor(0.6111) - >>> specificity = Specificity(average='micro') + >>> specificity = Specificity(task="multiclass", average='micro', num_classes=3) >>> specificity(preds, target) tensor(0.6250) """ - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False def __new__( cls, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinarySpecificity(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassSpecificity(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelSpecificity(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, - num_classes: Optional[int] = None, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - **kwargs: Any, - ) -> None: - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) - if "reduce" not in kwargs: - kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average - if "mdmc_reduce" not in kwargs: - kwargs["mdmc_reduce"] = mdmc_average - - super().__init__( - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - **kwargs, + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinarySpecificity(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassSpecificity(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelSpecificity(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - self.average = average - - def compute(self) -> Tensor: - """Computes the specificity score based on inputs passed in to ``update`` previously. - - Return: - The shape of the returned tensor depends on the ``average`` parameter: - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - """ - tp, fp, tn, fn = self._get_final_stats() - return _specificity_compute(tp, fp, tn, fn, self.average, self.mdmc_reduce) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 86f1e3fc714..4078b878057 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -33,13 +33,9 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, - _stat_scores_compute, - _stat_scores_update, ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn class _AbstractStatScores(Metric): @@ -493,274 +489,51 @@ def compute(self) -> Tensor: return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class StatScores(Metric): - r"""StatScores. +class StatScores: + r"""Computes the number of true positives, false positives, true negatives, false negatives and the support. - .. 
note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :mod:`BinaryStatScores`, :mod:`MulticlassStatScores` and :mod:`MultilabelStatScores` for the specific + details of each argument influence and examples. - Computes the number of true positives, false positives, true negatives, false negatives. - Related to `Type I and Type II errors`_ and the `confusion matrix`_. - - The reduction method (how the statistics are aggregated) is controlled by the - ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. - - - - Args: - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The default value (``None``) will be interpreted - as 1 for these inputs. Should be left at default (``None``) for all other types of inputs. - - reduce: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Counts the statistics by summing over all ``[sample, class]`` - combinations (globally). Each statistic is represented by a single integer. - - ``'macro'``: Counts the statistics for each class separately (over all samples). - Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` to be set. - - ``'samples'``: Counts the statistics for each sample separately (over all classes). - Each statistic is represented by a ``(N, )`` 1d tensor. - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_reduce``. - - num_classes: Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - - ignore_index: - Specify a class (label) to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and - ``reduce='macro'``, the class statistics for the ignored class will all be returned - as ``-1``. - - mdmc_reduce: Defines how the multi-dimensional multi-class inputs are handeled. Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then the outputs are concatenated together. In each - sample the extra axes ``...`` are flattened to become the sub-sample axis, and - statistics for each sample are computed by treating the sub-sample axis as the - ``N`` axis for that sample. 
- - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are - flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - - Raises: - ValueError: - If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. - ValueError: - If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``0`` <= ``ignore_index`` < ``num_classes``. - - Example: - >>> from torchmetrics.classification import StatScores + Legacy Example: >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> stat_scores = StatScores(reduce='macro', num_classes=3) + >>> stat_scores = StatScores(task="multiclass", num_classes=3, average='micro') + >>> stat_scores(preds, target) + tensor([2, 2, 6, 2, 4]) + >>> stat_scores = StatScores(task="multiclass", num_classes=3, average=None) >>> stat_scores(preds, target) tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) - >>> stat_scores = StatScores(reduce='micro') - >>> stat_scores(preds, target) - tensor([2, 2, 6, 2, 4]) """ - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - # TODO: canot be used because if scripting - # tp: Union[Tensor, List[Tensor]] - # fp: Union[Tensor, List[Tensor]] - # tn: Union[Tensor, List[Tensor]] - # fn: Union[Tensor, List[Tensor]] def __new__( cls, - num_classes: Optional[int] = None, - threshold: float = 0.5, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - num_labels: Optional[int] = None, - multidim_average: Optional[Literal["global", "samplewise"]] = "global", - validate_args: bool = True, - **kwargs: Any, - ) -> Metric: - if task is not None: - assert multidim_average is not None - kwargs.update( - dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) - ) - if task == "binary": - return BinaryStatScores(threshold, **kwargs) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return MulticlassStatScores(num_classes, top_k, average, **kwargs) - if task == "multilabel": - assert isinstance(num_labels, int) - return MultilabelStatScores(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'Binary*'`, `'Multiclass*', `'Multilabel*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - return super().__new__(cls) - - def __init__( - self, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - reduce: str = "micro", num_classes: Optional[int] = None, - ignore_index: Optional[int] = None, - mdmc_reduce: Optional[str] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - self.reduce = reduce - self.mdmc_reduce = mdmc_reduce - self.num_classes = num_classes - self.threshold = threshold - self.multiclass = multiclass - self.ignore_index = ignore_index - self.top_k = top_k - - if reduce not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {reduce} is not valid.") - - if mdmc_reduce not in [None, "samplewise", "global"]: - raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") - - if reduce == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - default: Callable = lambda: [] - reduce_fn: Optional[str] = "cat" - if mdmc_reduce != "samplewise" and reduce != "samples": - if reduce == "micro": - zeros_shape = [] - elif reduce == "macro": - zeros_shape = [num_classes] - else: - raise ValueError(f'Wrong reduce="{reduce}"') - default = lambda: torch.zeros(zeros_shape, dtype=torch.long) - reduce_fn = "sum" - - for s in ("tp", "fp", "tn", "fn"): - self.add_state(s, default=default(), dist_reduce_fx=reduce_fn) - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - """Update state with predictions and targets. 
- - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - """ - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=self.reduce, - mdmc_reduce=self.mdmc_reduce, - threshold=self.threshold, - num_classes=self.num_classes, - top_k=self.top_k, - multiclass=self.multiclass, - ignore_index=self.ignore_index, + ) -> Metric: + assert multidim_average is not None + kwargs.update(dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryStatScores(threshold, **kwargs) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return MulticlassStatScores(num_classes, top_k, average, **kwargs) + if task == "multilabel": + assert isinstance(num_labels, int) + return MultilabelStatScores(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - # Update states - if self.reduce != AverageMethod.SAMPLES and self.mdmc_reduce != MDMCAverageMethod.SAMPLEWISE: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn - else: - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - - def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Performs concatenation on the stat scores if neccesary, before passing them to a compute function.""" - tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp - fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp - tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn - fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn - return tp, fp, tn, fn - - def compute(self) -> Tensor: - """Computes the stat scores based on inputs passed in to ``update`` previously. - - Return: - The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds - to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The - shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional - multi-class data) parameters: - - - If the data is not multi-dimensional multi-class, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)``, - where ``C`` stands for the number of classes - - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for - the number of samples - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for - the product of sizes of all "extra" dimensions of the data (i.e. all dimensions - except for ``C`` and ``N``) - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then - - - If ``reduce='micro'``, the shape will be ``(N, 5)`` - - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` - """ - tp, fp, tn, fn = self._get_final_stats() - return _stat_scores_compute(tp, fp, tn, fn) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 339ac7de805..c3e2e1b25c0 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -83,24 +83,30 @@ class name as key for the output dict. 
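[Editor's note: illustrative sketch, not part of the patch] The ``__new__`` dispatch in the ``StatScores`` hunk above means the legacy class is now only a thin entry point that returns one of the task-specific metrics. Assuming the ``Binary*``/``Multiclass*`` classes introduced by this refactor are importable from ``torchmetrics.classification``, the expected behaviour is:

    >>> from torchmetrics.classification import BinaryStatScores, MulticlassStatScores, StatScores
    >>> isinstance(StatScores(task="binary"), BinaryStatScores)
    True
    >>> isinstance(StatScores(task="multiclass", num_classes=3), MulticlassStatScores)
    True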
Example (input as list): >>> import torch >>> from pprint import pprint - >>> from torchmetrics import MetricCollection, Accuracy, Precision, Recall, MeanSquaredError + >>> from torchmetrics import MetricCollection, MeanSquaredError + >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - >>> metrics = MetricCollection([Accuracy(), - ... Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')]) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + >>> metrics = MetricCollection([MulticlassAccuracy(num_classes=3, average='micro'), + ... MulticlassPrecision(num_classes=3, average='macro'), + ... MulticlassRecall(num_classes=3, average='macro')]) + >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'MulticlassAccuracy': tensor(0.1250), + 'MulticlassPrecision': tensor(0.0667), + 'MulticlassRecall': tensor(0.1111)} Example (input as arguments): - >>> metrics = MetricCollection(Accuracy(), Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + >>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'), + ... MulticlassPrecision(num_classes=3, average='macro'), + ... MulticlassRecall(num_classes=3, average='macro')) + >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'MulticlassAccuracy': tensor(0.1250), + 'MulticlassPrecision': tensor(0.0667), + 'MulticlassRecall': tensor(0.1111)} Example (input as dict): - >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), - ... 'macro_recall': Recall(num_classes=3, average='macro')}) + >>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'), + ... 'macro_recall': MulticlassRecall(num_classes=3, average='macro')}) >>> same_metric = metrics.clone() >>> pprint(metrics(preds, target)) {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)} @@ -109,33 +115,33 @@ class name as key for the output dict. Example (specification of compute groups): >>> metrics = MetricCollection( - ... Recall(num_classes=3, average='macro'), - ... Precision(num_classes=3, average='macro'), + ... MulticlassRecall(num_classes=3, average='macro'), + ... MulticlassPrecision(num_classes=3, average='macro'), ... MeanSquaredError(), - ... compute_groups=[['Recall', 'Precision'], ['MeanSquaredError']] + ... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']] ... ) >>> metrics.update(preds, target) >>> pprint(metrics.compute()) - {'MeanSquaredError': tensor(2.3750), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} + {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)} >>> pprint(metrics.compute_groups) - {0: ['Recall', 'Precision'], 1: ['MeanSquaredError']} + {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']} Example (nested metric collections): >>> metrics = MetricCollection([ ... MetricCollection([ - ... Accuracy(num_classes=3, average='macro'), - ... Precision(num_classes=3, average='macro') + ... MulticlassAccuracy(num_classes=3, average='macro'), + ... MulticlassPrecision(num_classes=3, average='macro') ... ], postfix='_macro'), ... 
MetricCollection([ - ... Accuracy(num_classes=3, average='micro'), - ... Precision(num_classes=3, average='micro') + ... MulticlassAccuracy(num_classes=3, average='micro'), + ... MulticlassPrecision(num_classes=3, average='micro') ... ], postfix='_micro'), ... ], prefix='valmetrics/') >>> pprint(metrics(preds, target)) # doctest: +NORMALIZE_WHITESPACE - {'valmetrics/Accuracy_macro': tensor(0.1111), - 'valmetrics/Accuracy_micro': tensor(0.1250), - 'valmetrics/Precision_macro': tensor(0.0667), - 'valmetrics/Precision_micro': tensor(0.1250)} + {'valmetrics/MulticlassAccuracy_macro': tensor(0.1111), + 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250), + 'valmetrics/MulticlassPrecision_macro': tensor(0.0667), + 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)} """ _groups: Dict[int, List[str]] diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index c0ba6fbf262..c8fce9ee422 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -27,7 +27,7 @@ from torchmetrics.functional.classification.hinge import hinge_loss from torchmetrics.functional.classification.jaccard import jaccard_index from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef -from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall +from torchmetrics.functional.classification.precision_recall import precision, recall from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve from torchmetrics.functional.classification.roc import roc from torchmetrics.functional.classification.specificity import specificity @@ -145,7 +145,6 @@ "perplexity", "pit_permutate", "precision", - "precision_recall", "precision_recall_curve", "peak_signal_noise_ratio", "r2_score", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 1a795eae5a8..6aa93dd9261 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -92,7 +92,6 @@ multilabel_precision, multilabel_recall, precision, - precision_recall, recall, ) from torchmetrics.functional.classification.precision_recall_curve import ( # noqa: F401 diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 444dbd6aeb9..445d50f2953 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
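[Editor's note: illustrative sketch, not part of the patch] Because ``MetricCollection`` derives its default keys from the class names, switching from ``Accuracy``/``Precision``/``Recall`` to the ``Multiclass*`` classes in the examples above also renames the keys of the returned dict (e.g. ``'Accuracy'`` becomes ``'MulticlassAccuracy'``). Users who want stable, self-chosen keys can pass a dict instead; reusing the ``preds``/``target`` tensors from the examples above and the hypothetical keys ``'acc'``/``'prec'``:

    >>> metrics = MetricCollection({"acc": MulticlassAccuracy(num_classes=3, average="micro"),
    ...                             "prec": MulticlassPrecision(num_classes=3, average="macro")})
    >>> sorted(metrics(preds, target))
    ['acc', 'prec']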
-from typing import Optional, Tuple +from typing import Optional import torch -from torch import Tensor, tensor +from torch import Tensor from typing_extensions import Literal from torchmetrics.functional.classification.stat_scores import ( @@ -30,13 +30,8 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, - _reduce_stat_scores, - _stat_scores_update, ) -from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _accuracy_reduce( @@ -386,437 +381,57 @@ def multilabel_accuracy( return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) -def _check_subset_validity(mode: DataType) -> bool: - """Checks input mode is valid.""" - return mode in (DataType.MULTILABEL, DataType.MULTIDIM_MULTICLASS) - - -def _mode( - preds: Tensor, - target: Tensor, - threshold: float, - top_k: Optional[int], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int] = None, -) -> DataType: - """Finds the mode of the input tensors. - - Args: - preds: Predicted tensor - target: Ground truth tensor - threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the - case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. - num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Example: - >>> target = torch.tensor([0, 1, 2, 3]) - >>> preds = torch.tensor([0, 2, 1, 3]) - >>> _mode(preds, target, 0.5, None, None, None) - - """ - - mode = _check_classification_inputs( - preds, - target, - threshold=threshold, - top_k=top_k, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - ) - return mode - - -def _accuracy_update( - preds: Tensor, - target: Tensor, - reduce: Optional[str], - mdmc_reduce: Optional[str], - threshold: float, - num_classes: Optional[int], - top_k: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - mode: DataType, -) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """Updates and returns stat scores (true positives, false positives, true negatives, false negatives) required - to compute accuracy. - - Args: - preds: Predicted tensor - target: Ground truth tensor - reduce: Defines the reduction that is applied. - mdmc_reduce: Defines how the multi-dimensional multi-class inputs are handled. - threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in - the case of binary or multi-label inputs. - num_classes: Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - top_k: Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. 
- multiclass: Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - ignore_index: Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - mode: Mode of the input tensors. - """ - - if mode == DataType.MULTILABEL and top_k: - raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") - preds, target = _input_squeeze(preds, target) - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, - mode=mode, - ) - return tp, fp, tn, fn - - -def _accuracy_compute( - tp: Tensor, - fp: Tensor, - tn: Tensor, - fn: Tensor, - average: Optional[str], - mdmc_average: Optional[str], - mode: DataType, -) -> Tensor: - """Computes accuracy from stat scores: true positives, false positives, true negatives, false negatives. - - Args: - tp: True positives - fp: False positives - tn: True negatives - fn: False negatives - average: Defines the reduction that is applied. - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). - mode: Mode of the input tensors - - Example: - >>> preds = torch.tensor([0, 2, 1, 3]) - >>> target = torch.tensor([0, 1, 2, 3]) - >>> threshold = 0.5 - >>> reduce = average = 'micro' - >>> mdmc_average = 'global' - >>> mode = _mode(preds, target, threshold, top_k=None, num_classes=None, multiclass=None) - >>> tp, fp, tn, fn = _accuracy_update( - ... preds, - ... target, - ... reduce, - ... mdmc_average, - ... threshold=0.5, - ... num_classes=None, - ... top_k=None, - ... multiclass=None, - ... ignore_index=None, - ... mode=mode) - >>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) - tensor(0.5000) - - >>> target = torch.tensor([0, 1, 2]) - >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> top_k, threshold = 2, 0.5 - >>> reduce = average = 'micro' - >>> mdmc_average = 'global' - >>> mode = _mode(preds, target, threshold, top_k, num_classes=None, multiclass=None) - >>> tp, fp, tn, fn = _accuracy_update(preds, target, reduce, mdmc_average, threshold, - ... num_classes=None, top_k=top_k, multiclass=None, ignore_index=None, mode=mode) - >>> _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) - tensor(0.6667) - """ - - simple_average = [AverageMethod.MICRO, AverageMethod.SAMPLES] - if (mode == DataType.BINARY and average in simple_average) or mode == DataType.MULTILABEL: - numerator = tp + tn - denominator = tp + tn + fp + fn - else: - numerator = tp.clone() - denominator = tp + fn - - if mdmc_average != MDMCAverageMethod.SAMPLEWISE: - if average == AverageMethod.MACRO: - cond = tp + fp + fn == 0 - numerator = numerator[~cond] - denominator = denominator[~cond] - - if average == AverageMethod.NONE: - # a class is not present if there exists no TPs, no FPs, and no FNs - meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() - numerator[meaningless_indeces, ...] = -1 - denominator[meaningless_indeces, ...] 
= -1 - - return _reduce_stat_scores( - numerator=numerator, - denominator=denominator, - weights=None if average != AverageMethod.WEIGHTED else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) - - -def _subset_accuracy_update( - preds: Tensor, - target: Tensor, - threshold: float, - top_k: Optional[int], - ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: - """Updates and returns variables required to compute subset accuracy. - - Args: - preds: Predicted tensor - target: Ground truth tensor - threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of ``0.5`` corresponds to input being probabilities. - top_k: Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. - """ - - preds, target = _input_squeeze(preds, target) - preds, target, mode = _input_format_classification( - preds, target, threshold=threshold, top_k=top_k, ignore_index=ignore_index - ) - - if mode == DataType.MULTILABEL and top_k: - raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") - - if mode == DataType.MULTILABEL: - correct = (preds == target).all(dim=1).sum() - total = tensor(target.shape[0], device=target.device) - elif mode == DataType.MULTICLASS: - correct = (preds * target).sum() - total = target.sum() - elif mode == DataType.MULTIDIM_MULTICLASS: - sample_correct = (preds * target).sum(dim=(1, 2)) - correct = (sample_correct == target.shape[2]).sum() - total = tensor(target.shape[0], device=target.device) - else: - correct, total = tensor(0), tensor(0) - - return correct, total - - -def _subset_accuracy_compute(correct: Tensor, total: Tensor) -> Tensor: - """Computes subset accuracy from number of correct observations and total number of observations. - - Args: - correct: Number of correct observations - total: Number of observations - """ - - return correct.float() / total - - def accuracy( preds: Tensor, target: Tensor, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = "global", + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - subset_accuracy: bool = False, num_classes: Optional[int] = None, - multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Accuracy. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Accuracy`_ + r"""Computes `Accuracy`_ .. 
math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) - Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. - - For multi-class and multi-dimensional multi-class data with probability or logits predictions, the - parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the - top-K highest probability or logits items are considered to find the correct label. - - For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" - accuracy by default, which counts all labels or sub-samples separately. This can be - changed to subset accuracy (which requires all labels or sub-samples in the sample to - be correctly predicted) by setting ``subset_accuracy=True``. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth labels - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions. - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_accuracy`, :func:`multiclass_accuracy` and :func:`multilabel_accuracy` for the specific details of + each argument influence and examples. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. 
- top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - subset_accuracy: - Whether to compute subset accuracy for multi-label and multi-dimensional - multi-class inputs (has no effect for other input types). - - - For multi-label inputs, if the parameter is set to ``True``, then all labels for - each sample must be correctly predicted for the sample to count as correct. If it - is set to ``False``, then all labels are counted separately - this is equivalent to - flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). - - - For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all - sub-sample (on the extra axis) must be correct for the sample to be counted as correct. - If it is set to ``False``, then all sub-samples are counter separately - this is equivalent, - in the case of label predictions, to flattening the inputs beforehand (i.e. - ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter - still applies in both cases, if set. - - Raises: - ValueError: - If ``top_k`` parameter is set for ``multi-label`` inputs. - ValueError: - If ``average`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"``, ``None``. - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set - and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``top_k`` is not an ``integer`` larger than ``0``. 
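[Editor's note: illustrative sketch, not part of the patch] For the binary task the refactored ``accuracy`` wrapper simply forwards ``threshold`` and the shared arguments to ``binary_accuracy``; with arbitrarily chosen probabilities this looks like:

    >>> import torch
    >>> from torchmetrics.functional import accuracy
    >>> preds = torch.tensor([0.1, 0.8, 0.6, 0.3])
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> accuracy(preds, target, task="binary", threshold=0.5)
    tensor(1.)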
- - Example: + Legacy Example: >>> import torch - >>> from torchmetrics.functional import accuracy >>> target = torch.tensor([0, 1, 2, 3]) >>> preds = torch.tensor([0, 2, 1, 3]) - >>> accuracy(preds, target) + >>> accuracy(preds, target, task="multiclass", num_classes=4) tensor(0.5000) >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) - >>> accuracy(preds, target, top_k=2) + >>> accuracy(preds, target, task="multiclass", num_classes=3, top_k=2) tensor(0.6667) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_accuracy( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_accuracy( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_accuracy( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_accuracy( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if num_classes and ignore_index is not None and (not ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - if top_k is not None and (not isinstance(top_k, int) or top_k <= 0): - raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}") - - preds, target = _input_squeeze(preds, target) - mode = _mode(preds, target, threshold, top_k, num_classes, multiclass, ignore_index) - reduce = "macro" if average in ["weighted", "none", None] else average - - if subset_accuracy and _check_subset_validity(mode): - correct, total = _subset_accuracy_update(preds, target, threshold, top_k, ignore_index) - return _subset_accuracy_compute(correct, total) - tp, fp, tn, fn = _accuracy_update( - preds, target, reduce, mdmc_average, threshold, num_classes, top_k, multiclass, ignore_index, mode + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - return _accuracy_compute(tp, fp, tn, fn, average, mdmc_average, mode) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 1b064c35f02..989eabd0ef3 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -11,8 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
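[Editor's note: illustrative sketch, not part of the patch] The dispatch at the end of ``accuracy`` above means the wrapper and the specialized functions are interchangeable; the second call of the Legacy Example can equally be written against ``multiclass_accuracy`` directly (the explicit ``average='micro'`` mirrors the wrapper's default):

    >>> import torch
    >>> from torchmetrics.functional.classification import multiclass_accuracy
    >>> target = torch.tensor([0, 1, 2])
    >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
    >>> multiclass_accuracy(preds, target, num_classes=3, top_k=2, average="micro")
    tensor(0.6667)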
-import warnings -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, tensor @@ -36,12 +35,9 @@ _binary_roc_compute, _multiclass_roc_compute, _multilabel_roc_compute, - roc, ) -from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount -from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.utilities.prints import rank_zero_warn @@ -418,275 +414,50 @@ def multilabel_auroc( return _multilabel_auroc_compute(state, num_labels, average, thresholds, ignore_index) -def _auroc_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, DataType]: - """Updates and returns variables required to compute Area Under the Receiver Operating Characteristic Curve. - Validates the inputs and returns the mode of the inputs. - - Args: - preds: Predicted tensor - target: Ground truth tensor - """ - - # use _input_format_classification for validating the input and get the mode of data - _, _, mode = _input_format_classification(preds, target) - - if mode == "multi class multi dim": - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.flatten() - if mode == "multi-label" and preds.ndim > 2: - n_classes = preds.shape[1] - preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1) - - return preds, target, mode - - -def _auroc_compute( - preds: Tensor, - target: Tensor, - mode: DataType, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", - max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, -) -> Tensor: - """Computes Area Under the Receiver Operating Characteristic Curve. - - Args: - preds: predictions from model (logits or probabilities) - target: Ground truth labels - mode: 'multi class multi dim' or 'multi-label' or 'binary' - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems - pos_label: integer determining the positive class. - Should be set to ``None`` for binary problems - average: Defines the reduction that is applied to the output: - max_fpr: If not ``None``, calculates standardized partial AUC over the - range ``[0, max_fpr]``. Should be a float between 0 and 1. - sample_weights: sample weights for each data point - - Example: - >>> # binary case - >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) - >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> preds, target, mode = _auroc_update(preds, target) - >>> _auroc_compute(preds, target, mode, pos_label=1) - tensor(0.5000) - - >>> # multiclass case - >>> preds = torch.tensor([[0.90, 0.05, 0.05], - ... [0.05, 0.90, 0.05], - ... [0.05, 0.05, 0.90], - ... [0.85, 0.05, 0.10], - ... 
[0.10, 0.10, 0.80]]) - >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> preds, target, mode = _auroc_update(preds, target) - >>> _auroc_compute(preds, target, mode, num_classes=3) - tensor(0.7778) - """ - - # binary mode override num_classes - if mode == DataType.BINARY: - num_classes = 1 - - # check max_fpr parameter - if max_fpr is not None: - if not isinstance(max_fpr, float) and 0 < max_fpr <= 1: - raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}") - - # max_fpr parameter is only support for binary - if mode != DataType.BINARY: - raise ValueError( - "Partial AUC computation not available in multilabel/multiclass setting," - f" 'max_fpr' must be set to `None`, received `{max_fpr}`." - ) - - # calculate fpr, tpr - if mode == DataType.MULTILABEL: - if average == AverageMethod.MICRO: - fpr, tpr, _ = roc(preds.flatten(), target.flatten(), 1, pos_label, sample_weights) - elif num_classes: - # for multilabel we iteratively evaluate roc in a binary fashion - output = [ - roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights) - for i in range(num_classes) - ] - fpr = [o[0] for o in output] - tpr = [o[1] for o in output] - else: - raise ValueError("Detected input to be `multilabel` but you did not provide `num_classes` argument") - else: - if mode != DataType.BINARY: - if num_classes is None: - raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument") - if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes: - # If one or more classes has 0 observations, we should exclude them, as its weight will be 0 - target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool, device=target.device) - target_bool_mat[torch.arange(len(target)), target.long()] = 1 - class_observed = target_bool_mat.sum(axis=0) > 0 - for c in range(num_classes): - if not class_observed[c]: - warnings.warn(f"Class {c} had 0 observations, omitted from AUROC calculation", UserWarning) - preds = preds[:, class_observed] - target = target_bool_mat[:, class_observed] - target = torch.where(target)[1] - num_classes = class_observed.sum() - if num_classes == 1: - raise ValueError("Found 1 non-empty class in `multiclass` AUROC calculation") - fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights) - - # calculate standard roc auc score - if max_fpr is None or max_fpr == 1: - if mode == DataType.MULTILABEL and average == AverageMethod.MICRO: - pass - elif num_classes != 1: - # calculate auc scores per class - auc_scores = [_auc_compute_without_check(x, y, 1.0) for x, y in zip(fpr, tpr)] - - # calculate average - if average == AverageMethod.NONE: - return tensor(auc_scores) - if average == AverageMethod.MACRO: - return torch.mean(torch.stack(auc_scores)) - if average == AverageMethod.WEIGHTED: - if mode == DataType.MULTILABEL: - support = torch.sum(target, dim=0) - else: - support = _bincount(target.flatten(), minlength=num_classes) - return torch.sum(torch.stack(auc_scores) * support / support.sum()) - - allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value) - raise ValueError( - f"Argument `average` expected to be one of the following: {allowed_average} but got {average}" - ) - - return _auc_compute_without_check(fpr, tpr, 1.0) - - _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device - max_area: Tensor = tensor(max_fpr, device=_device) - # Add a single point at max_fpr and interpolate its tpr value - stop = 
torch.bucketize(max_area, fpr, out_int32=True, right=True) - weight = (max_area - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1]) - interp_tpr: Tensor = torch.lerp(tpr[stop - 1], tpr[stop], weight) - tpr = torch.cat([tpr[:stop], interp_tpr.view(1)]) - fpr = torch.cat([fpr[:stop], max_area.view(1)]) - - # Compute partial AUC - partial_auc = _auc_compute_without_check(fpr, tpr, 1.0) - - # McClish correction: standardize result to be 0.5 if non-discriminant and 1 if maximal - min_area: Tensor = 0.5 * max_area**2 - return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area)) - - def auroc( preds: Tensor, target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, - pos_label: Optional[int] = None, + num_labels: Optional[int] = None, average: Optional[Literal["macro", "weighted", "none"]] = "macro", max_fpr: Optional[float] = None, - sample_weights: Optional[Sequence] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, - num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Area Under the Receiver Operating Characteristic Curve. + r"""Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_). The AUROC score summarizes the + ROC curve into an single number that describes the performance of a model for multiple thresholds at the same + time. Notably, an AUROC score of 1 is a perfect score and an AUROC score of 0.5 corresponds to random guessing. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_auroc`, :func:`multiclass_auroc` and :func:`multilabel_auroc` for the specific details of + each argument influence and examples. - Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) - - For non-binary input, if the ``preds`` and ``target`` tensor have the same - size the input will be interpretated as multilabel and if ``preds`` have one - dimension more than the ``target`` tensor the input will be interpretated as - multiclass. - - .. note:: - If either the positive class or negative class is completly missing in the target tensor, - the auroc score is meaningless in this case and a score of 0 will be returned together - with a warning. - - Args: - preds: predictions from model (logits or probabilities) - target: Ground truth labels - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translate to 1. 
For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - average: - - - ``'macro'`` computes metric for each class and uniformly averages them - - ``'weighted'`` computes metric for each class and does a weighted-average, - where each class is weighted by their support (accounts for class imbalance) - - ``None`` computes and returns the metric per class - - max_fpr: - If not ``None``, calculates standardized partial AUC over the - range ``[0, max_fpr]``. Should be a float between 0 and 1. - sample_weights: sample weights for each data point - - Raises: - ValueError: - If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``. - RuntimeError: - If ``PyTorch version`` is below 1.6 since max_fpr requires ``torch.bucketize`` - which is not available below 1.6. - ValueError: - If ``max_fpr`` is not set to ``None`` and the mode is ``not binary`` - since partial AUC computation is not available in multilabel/multiclass. - ValueError: - If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``. - - Example (binary case): - >>> from torchmetrics.functional import auroc + Legacy Example: >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34]) >>> target = torch.tensor([0, 0, 1, 1, 1]) - >>> auroc(preds, target, pos_label=1) + >>> auroc(preds, target, task='binary') tensor(0.5000) - Example (multiclass case): >>> preds = torch.tensor([[0.90, 0.05, 0.05], ... [0.05, 0.90, 0.05], ... [0.05, 0.05, 0.90], ... [0.85, 0.05, 0.10], ... [0.10, 0.10, 0.80]]) >>> target = torch.tensor([0, 1, 1, 2, 2]) - >>> auroc(preds, target, num_classes=3) + >>> auroc(preds, target, task='multiclass', num_classes=3) tensor(0.7778) """ - if task is not None: - if task == "binary": - return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - - preds, target, mode = _auroc_update(preds, target) - return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) + if task == "binary": + return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 259454be6ec..792e1e3600f 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import warnings from typing import List, Optional, Tuple, Union import torch @@ -34,8 +33,6 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, - _precision_recall_curve_compute, - _precision_recall_curve_update, ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount @@ -188,9 +185,9 @@ def multiclass_average_precision( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve - as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold - as weight: + r"""Computes the average precision (AP) score for multiclass tasks. The AP score summarizes a precision-recall + curve as an weighted mean of precisions at each threshold, with the difference in recall from the previous + threshold as weight: .. math:: AP = \sum{n} (R_n - R_{n-1}) P_n @@ -317,9 +314,9 @@ def multilabel_average_precision( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Computes the average precision (AP) score for binary tasks. The AP score summarizes a precision-recall curve - as an weighted mean of precisions at each threshold, with the difference in recall from the previous threshold - as weight: + r"""Computes the average precision (AP) score for multilabel tasks. The AP score summarizes a precision-recall + curve as an weighted mean of precisions at each threshold, with the difference in recall from the previous + threshold as weight: .. 
math:: AP = \sum{n} (R_n - R_{n-1}) P_n @@ -401,241 +398,56 @@ def multilabel_average_precision( return _multilabel_average_precision_compute(state, num_labels, average, thresholds, ignore_index) -def _average_precision_update( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", -) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Format the predictions and target based on the ``num_classes``, ``pos_label`` and ``average`` parameter. - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: reduction method for multi-class or multi-label problems - """ - preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) - if average == "micro" and preds.ndim != target.ndim: - raise ValueError("Cannot use `micro` average with multi-class input") - - return preds, target, num_classes, pos_label - - -def _average_precision_compute( - preds: Tensor, - target: Tensor, - num_classes: int, - pos_label: Optional[int] = None, - average: Optional[str] = "macro", -) -> Union[List[Tensor], Tensor]: - """Computes the average precision score. - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems his argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: reduction method for multi-class or multi-label problems - - Example: - >>> # binary case - >>> preds = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> pos_label = 1 - >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label) - >>> _average_precision_compute(preds, target, num_classes, pos_label) - tensor(1.) - - >>> # multiclass case - >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 5 - >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes) - >>> _average_precision_compute(preds, target, num_classes, average=None) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - """ - - if average == "micro" and preds.ndim == target.ndim: - preds = preds.flatten() - target = target.flatten() - num_classes = 1 - - precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) - if average == "weighted": - if preds.ndim == target.ndim and target.ndim > 1: - weights = target.sum(dim=0).float() - else: - weights = _bincount(target, minlength=max(num_classes, 2)).float() - weights = weights / torch.sum(weights) - else: - weights = None - return _average_precision_compute_with_precision_recall(precision, recall, num_classes, average, weights) - - -def _average_precision_compute_with_precision_recall( - precision: Tensor, - recall: Tensor, - num_classes: int, - average: Optional[str] = "macro", - weights: Optional[Tensor] = None, -) -> Union[List[Tensor], Tensor]: - """Computes the average precision score from precision and recall. - - Args: - precision: precision values - recall: recall values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - average: reduction method for multi-class or multi-label problems - weights: weights to use when average='weighted' - - Example: - >>> # binary case - >>> preds = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> pos_label = 1 - >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, pos_label=pos_label) - >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label) - >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None) - tensor(1.) - - >>> # multiclass case - >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 5 - >>> preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes) - >>> precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes) - >>> _average_precision_compute_with_precision_recall(precision, recall, num_classes, average=None) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] - """ - - # Return the step function integral - # The following works because the last entry of precision is - # guaranteed to be 1, as returned by precision_recall_curve - if num_classes == 1: - return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) - - res = [] - for p, r in zip(precision, recall): - res.append(-torch.sum((r[1:] - r[:-1]) * p[:-1])) - - # Reduce - if average in ("macro", "weighted"): - res = torch.stack(res) - if torch.isnan(res).any(): - warnings.warn( - "Average precision score for one or more classes was `nan`. 
Ignoring these classes in average", - UserWarning, - ) - if average == "macro": - return res[~torch.isnan(res)].mean() - weights = torch.ones_like(res) if weights is None else weights - return (res * weights)[~torch.isnan(res)].sum() - if average is None or average == "none": - return res - allowed_average = ("micro", "macro", "weighted", "none", None) - raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}") - - def average_precision( preds: Tensor, target: Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - average: Optional[Literal["macro", "weighted", "none"]] = "macro", - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["macro", "weighted", "none"]] = "macro", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[List[Tensor], Tensor]: - r"""Average precision. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes the average precision score. + r"""Computes the average precision (AP) score. The AP score summarizes a precision-recall curve as an weighted + mean of precisions at each threshold, with the difference in recall from the previous threshold as weight: - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes. Not nessesary to provide - for binary problems. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems his argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - average: - defines the reduction that is applied in the case of multiclass and multilabel input. - Should be one of the following: + .. math:: + AP = \sum{n} (R_n - R_{n-1}) P_n - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support. - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. + where :math:`P_n, R_n` is the respective precision and recall at threshold index :math:`n`. This value is + equivalent to the area under the precision-recall curve (AUPRC). - Returns: - tensor with average precision. If multiclass it returns list - of such tensors, one for each class + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. 
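[Editor's note: worked illustration, not part of the patch] For the binary call in the Legacy Example below (scores ``0., 1., 2., 3.`` against targets ``0, 1, 1, 1``), sweeping the threshold from the highest score downwards gives the (recall, precision) pairs (1/3, 1), (2/3, 1), (1, 1) and (1, 3/4), so AP = (1/3 - 0)*1 + (2/3 - 1/3)*1 + (1 - 2/3)*1 + (1 - 1)*(3/4) = 1, which matches the ``tensor(1.)`` shown there.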
See the documentation of + :func:`binary_average_precision`, :func:`multiclass_average_precision` and :func:`multilabel_average_precision` + for the specific details of each argument influence and examples. - Example (binary case): + Legacy Example: >>> from torchmetrics.functional import average_precision - >>> pred = torch.tensor([0, 1, 2, 3]) + >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> average_precision(pred, target, pos_label=1) + >>> average_precision(pred, target, task="binary") tensor(1.) - Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> average_precision(pred, target, num_classes=5, average=None) - [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] + >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None) + tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ - if task is not None: - if task == "binary": - return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_average_precision( - preds, target, num_classes, average, thresholds, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_average_precision( - preds, target, num_labels, average, thresholds, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + if task == "binary": + return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_average_precision( + preds, target, num_classes, average, thresholds, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label, average) - return _average_precision_compute(preds, target, num_classes, pos_label, average) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index d8b875c10d1..4545dec0742 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,9 +23,6 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import DataType -from torchmetrics.utilities.prints import rank_zero_warn def _binning_bucketize( @@ -316,123 +313,44 @@ def multiclass_calibration_error( return _ce_compute(confidences, accuracies, n_bins, norm) -def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: - """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their - correctness. - - Args: - preds: Input ``softmaxed`` predictions. - target: Labels. - - Raises: - ValueError: If the dataset shape is not binary, multiclass, or multidimensional-multiclass. - - Returns: - tuple with confidences and accuracies - """ - _, _, mode = _input_format_classification(preds, target) - - if mode == DataType.BINARY: - if not ((0 <= preds) * (preds <= 1)).all(): - preds = preds.sigmoid() - confidences, accuracies = preds, target - elif mode == DataType.MULTICLASS: - if not ((0 <= preds) * (preds <= 1)).all(): - preds = preds.softmax(dim=1) - confidences, predictions = preds.max(dim=1) - accuracies = predictions.eq(target) - elif mode == DataType.MULTIDIM_MULTICLASS: - # reshape tensors - # for preds, move the class dimension to the final axis and flatten the rest - confidences, predictions = torch.transpose(preds, 1, -1).flatten(0, -2).max(dim=1) - # for targets, just flatten the target - accuracies = predictions.eq(target.flatten()) - else: - raise ValueError( - f"Calibration error is not well-defined for data with size {preds.size()} and targets {target.size()}." - ) - # must be cast to float for ddp allgather to work - return confidences.float(), accuracies.float() - - def calibration_error( preds: Tensor, target: Tensor, + task: Literal["binary", "multiclass"] = None, n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Calibration Error. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. 
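For the `average_precision` wrapper above, the multilabel branch dispatches to `multilabel_average_precision` in the same way as the binary and multiclass branches shown in the legacy example. A minimal sketch of that branch, assuming the task-based signature introduced in this patch (tensor values are illustrative only and outputs are not asserted):

import torch
from torchmetrics.functional import average_precision

# Three samples, two labels; preds hold per-label probabilities.
preds = torch.tensor([[0.75, 0.05],
                      [0.05, 0.75],
                      [0.05, 0.05]])
target = torch.tensor([[1, 0],
                       [0, 1],
                       [0, 0]])

# Dispatches to multilabel_average_precision; average=None returns one AP score per label.
ap_per_label = average_precision(preds, target, task="multilabel", num_labels=2, average=None)
print(ap_per_label)  # shape (2,), one AP score per label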
From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - `Computes the Top-label Calibration Error`_ + r"""`Computes the Top-label Calibration Error`_. The expected calibration error can be used to quantify how well + a given model is calibrated e.g. how well the predicted output probabilities of the model matches the actual + probabilities of the ground truth distribution. Three different norms are implemented, each corresponding to variations on the calibration error metric. - L1 norm (Expected Calibration Error) - .. math:: - \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\| - - Infinity norm (Maximum Calibration Error) + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} .. math:: - \text{MCE} = \max_{i} (p_i - c_i) - - L2 norm (Root Mean Square Calibration Error) + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} .. math:: - \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2} - - Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, - :math:`c_i` is the average confidence of predictions in bin :math:`i`, and - :math:`b_i` is the fraction of data points in bin :math:`i`. + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} - .. note: - L2-norm debiasing is not yet supported. + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. - Args: - preds: Model output probabilities. - target: Ground-truth target class labels. - n_bins: Number of bins to use when computing t. - norm: Norm used to compare empirical and expected probability bins. - Defaults to "l1", or Expected Calibration Error. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :func:`binary_calibration_error` and :func:`multiclass_calibration_error` for the specific details of + each argument influence and examples. """ - if task is not None: - assert norm is not None - if task == "binary": - return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) - raise ValueError(f"Expected argument `task` to either be `'binary'`, `'multiclass'` but got {task}") - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - - if norm not in ("l1", "l2", "max"): - raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. 
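For the `calibration_error` wrapper above, a minimal sketch of the binary branch under the task-based signature from this patch (inputs are illustrative; outputs are not asserted):

import torch
from torchmetrics.functional import calibration_error

# Predicted probabilities for the positive class and the true binary labels.
preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
target = torch.tensor([0, 0, 1, 1, 1])

# Expected Calibration Error (L1 norm) over 3 bins; dispatches to binary_calibration_error.
ece = calibration_error(preds, target, task="binary", n_bins=3, norm="l1")

# Maximum Calibration Error (max norm) on the same inputs.
mce = calibration_error(preds, target, task="binary", n_bins=3, norm="max")
print(ece, mce)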
") - - if not isinstance(n_bins, int) or n_bins <= 0: - raise ValueError(f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}") - - confidences, accuracies = _ce_update(preds, target) - - bin_boundaries = torch.linspace(0, 1, n_bins + 1, dtype=torch.float, device=preds.device) - - return _ce_compute(confidences, accuracies, bin_boundaries, norm=norm) + assert norm is not None + if task == "binary": + return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) + raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 10565471dfc..b4d9c1217a0 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -22,14 +22,11 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, - _confusion_matrix_compute, - _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) -from torchmetrics.utilities.prints import rank_zero_warn def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -230,78 +227,17 @@ class labels. return _cohen_kappa_reduce(confmat, weights) -_cohen_kappa_update = _confusion_matrix_update - - -def _cohen_kappa_compute(confmat: Tensor, weights: Optional[str] = None) -> Tensor: - """Computes Cohen's kappa based on the weighting type. - - Args: - confmat: Confusion matrix without normalization - weights: Weighting type to calculate the score. 
Choose from: - - - ``None`` or ``'none'``: no weighting - - ``'linear'``: linear weighting - - ``'quadratic'``: quadratic weighting - - Example: - >>> target = torch.tensor([1, 1, 0, 0]) - >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = _cohen_kappa_update(preds, target, num_classes=2) - >>> _cohen_kappa_compute(confmat) - tensor(0.5000) - """ - - confmat = _confusion_matrix_compute(confmat) - confmat = confmat.float() if not confmat.is_floating_point() else confmat - n_classes = confmat.shape[0] - sum0 = confmat.sum(dim=0, keepdim=True) - sum1 = confmat.sum(dim=1, keepdim=True) - expected = sum1 @ sum0 / sum0.sum() # outer product - - if weights is None: - w_mat = torch.ones_like(confmat).flatten() - w_mat[:: n_classes + 1] = 0 - w_mat = w_mat.reshape(n_classes, n_classes) - elif weights in ("linear", "quadratic"): - w_mat = torch.zeros_like(confmat) - w_mat += torch.arange(n_classes, dtype=w_mat.dtype, device=w_mat.device) - if weights == "linear": - w_mat = torch.abs(w_mat - w_mat.T) - else: - w_mat = torch.pow(w_mat - w_mat.T, 2.0) - else: - raise ValueError( - f"Received {weights} for argument ``weights`` but should be either" " None, 'linear' or 'quadratic'" - ) - - k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected) - return 1 - k - - def cohen_kappa( preds: Tensor, target: Tensor, - num_classes: int, - weights: Optional[Literal["linear", "quadratic", "none"]] = None, + task: Literal["binary", "multiclass"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Cohen's kappa. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - - Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. - - It is defined as + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as. .. math:: \kappa = (p_o - p_e) / (1 - p_e) @@ -311,43 +247,20 @@ def cohen_kappa( :math:`p_e` is estimated using a per-annotator empirical prior over the class labels. - Args: - preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or - ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities - target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels - num_classes: Number of classes in the dataset. - weights: Weighting type to calculate the score. Choose from: - - - ``None`` or ``'none'``: no weighting - - ``'linear'``: linear weighting - - ``'quadratic'``: quadratic weighting - - threshold: Threshold value for binary or multi-label probabilities. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :func:`binary_cohen_kappa` and :func:`multiclass_cohen_kappa` for the specific details of + each argument influence and examples. 
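A minimal sketch of the binary branch, complementing the multiclass legacy example that follows (assumes the task-based signature from this patch; the value in the comment is computed by hand, not asserted):

import torch
from torchmetrics.functional import cohen_kappa

# Binary case: float preds are binarized at `threshold` (default 0.5) before the
# confusion matrix underlying kappa is built; dispatches to binary_cohen_kappa.
preds = torch.tensor([0.1, 0.9, 0.4, 0.8])
target = torch.tensor([0, 1, 1, 1])

kappa = cohen_kappa(preds, target, task="binary")
print(kappa)  # unweighted kappa: (p_o - p_e) / (1 - p_e) = (0.75 - 0.5) / 0.5 = 0.5 for these inputs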
- Example: - >>> from torchmetrics.functional import cohen_kappa + Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> cohen_kappa(preds, target, num_classes=2) + >>> cohen_kappa(preds, target, task="multiclass", num_classes=2) tensor(0.5000) """ - if task is not None: - if task == "binary": - return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) - if task == "multiclass": - return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - - confmat = _cohen_kappa_update(preds, target, num_classes, threshold) - return _cohen_kappa_compute(confmat, weights) + if task == "binary": + return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) + raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 086df2673df..3d5b76e1ad5 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -17,9 +17,8 @@ from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount -from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.prints import rank_zero_warn @@ -592,200 +591,57 @@ def multilabel_confusion_matrix( return _multilabel_confusion_matrix_compute(confmat, normalize) -def _confusion_matrix_update( - preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False -) -> Tensor: - """Updates and returns confusion matrix (without any normalization) based on the mode of the input. - - Args: - preds: Predicted tensor - target: Ground truth tensor - threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the - case of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - multilabel: determines if data is multilabel or not. 
- """ - - preds, target, mode = _input_format_classification(preds, target, threshold) - if mode not in (DataType.BINARY, DataType.MULTILABEL): - preds = preds.argmax(dim=1) - target = target.argmax(dim=1) - if multilabel: - unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_classes, device=preds.device)).flatten() - minlength = 4 * num_classes - else: - unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) - minlength = num_classes**2 - - bins = _bincount(unique_mapping, minlength=minlength) - if multilabel: - confmat = bins.reshape(num_classes, 2, 2) - else: - confmat = bins.reshape(num_classes, num_classes) - return confmat - - -def _confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - """Computes confusion matrix based on the normalization mode. - - Args: - confmat: Confusion matrix without normalization - normalize: Normalization mode for confusion matrix. Choose from: - - - ``None`` or ``'none'``: no normalization (default) - - ``'true'``: normalization over the targets (most commonly used) - - ``'pred'``: normalization over the predictions - - ``'all'``: normalization over the whole matrix - - Example: - >>> # binary case - >>> target = torch.tensor([1, 1, 0, 0]) - >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = _confusion_matrix_update(preds, target, num_classes=2) - >>> _confusion_matrix_compute(confmat) - tensor([[2, 0], - [1, 1]]) - - >>> # multiclass case - >>> target = torch.tensor([2, 1, 0, 0]) - >>> preds = torch.tensor([2, 1, 0, 1]) - >>> confmat = _confusion_matrix_update(preds, target, num_classes=3) - >>> _confusion_matrix_compute(confmat) - tensor([[1, 1, 0], - [0, 1, 0], - [0, 0, 1]]) - - >>> # multilabel case - >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) - >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) - >>> confmat = _confusion_matrix_update(preds, target, num_classes=3, multilabel=True) - >>> _confusion_matrix_compute(confmat) - tensor([[[1, 0], [0, 1]], - [[1, 0], [1, 0]], - [[0, 1], [0, 1]]]) - """ - - allowed_normalize = ("true", "pred", "all", "none", None) - if normalize not in allowed_normalize: - raise ValueError(f"Argument average needs to one of the following: {allowed_normalize}") - if normalize is not None and normalize != "none": - confmat = confmat.float() if not confmat.is_floating_point() else confmat - if normalize == "true": - confmat = confmat / confmat.sum(axis=1, keepdim=True) - elif normalize == "pred": - confmat = confmat / confmat.sum(axis=0, keepdim=True) - elif normalize == "all": - confmat = confmat / confmat.sum() - - nan_elements = confmat[torch.isnan(confmat)].nelement() - if nan_elements != 0: - confmat[torch.isnan(confmat)] = 0 - rank_zero_warn(f"{nan_elements} nan values found in confusion matrix have been replaced with zeros.") - return confmat - - def confusion_matrix( preds: Tensor, target: Tensor, - num_classes: int, - normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - multilabel: bool = False, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Confusion matrix. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. 
Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - + r"""Computes the `confusion matrix`_. - Computes the `confusion matrix`_. Works with binary, - multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class - values in prediction. Works with multi-dimensional preds and target, but it should be noted that - additional dimensions will be flattened. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_confusion_matrix`, :func:`multiclass_confusion_matrix` and :func:`multilabel_confusion_matrix` for + the specific details of each argument influence and examples. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities or logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - If working with multilabel data, setting the ``is_multilabel`` argument to ``True`` will make sure that a - `confusion matrix gets calculated per label`_. - - Args: - preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or - ``(N, C, ...)`` where C is the number of classes, tensor with labels/logits/probabilities - target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels - num_classes: Number of classes in the dataset. - normalize: Normalization mode for confusion matrix. Choose from: - - - ``None`` or ``'none'``: no normalization (default) - - ``'true'``: normalization over the targets (most commonly used) - - ``'pred'``: normalization over the predictions - - ``'all'``: normalization over the whole matrix - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - - multilabel: - determines if data is multilabel or not. 
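The class-based legacy examples that follow use `ConfusionMatrix`; the functional wrapper dispatches identically. A minimal sketch of the functional binary branch, assuming the task-based signature above (the expected matrix mirrors the legacy binary example):

import torch
from torchmetrics.functional import confusion_matrix

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])

# Dispatches to binary_confusion_matrix.
cm = confusion_matrix(preds, target, task="binary")
print(cm)  # expected: tensor([[2, 0], [1, 1]])

# normalize="true" divides each row by the number of samples of that true class.
cm_norm = confusion_matrix(preds, target, task="binary", normalize="true")
print(cm_norm)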
- - Example (binary data): + Legacy Example: >>> from torchmetrics import ConfusionMatrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = ConfusionMatrix(num_classes=2) + >>> confmat = ConfusionMatrix(task="binary") >>> confmat(preds, target) tensor([[2, 0], [1, 1]]) - Example (multiclass data): >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) - >>> confmat = ConfusionMatrix(num_classes=3) + >>> confmat = ConfusionMatrix(task="multiclass", num_classes=3) >>> confmat(preds, target) tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]]) - Example (multilabel data): >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) - >>> confmat = ConfusionMatrix(num_classes=3, multilabel=True) + >>> confmat = ConfusionMatrix(task="multilabel", num_labels=3) >>> confmat(preds, target) tensor([[[1, 0], [0, 1]], [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ - if task is not None: - if task == "binary": - return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_confusion_matrix( - preds, target, num_labels, threshold, normalize, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel) - return _confusion_matrix_compute(confmat, normalize) + if task == "binary": + return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 5379d1d45eb..301321bfaeb 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -38,15 +38,6 @@ def _dice_compute( average: Defines the reduction that is applied mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the ``average`` parameter) - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> from torchmetrics.functional.classification.dice import _dice_compute - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') - >>> _dice_compute(tp, fp, fn, average='micro', mdmc_average=None) - tensor(0.2500) """ numerator = 2 * tp denominator = 2 * tp + fp + fn diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 1cdb1261a3d..6c7aebba8bc 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -30,13 +30,8 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, - _reduce_stat_scores, - _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import AverageMethod as AvgMethod -from torchmetrics.utilities.enums import MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _fbeta_reduce( @@ -700,401 +695,99 @@ def multilabel_f1_score( ) -def _fbeta_compute( - tp: Tensor, - fp: Tensor, - tn: Tensor, - fn: Tensor, - beta: float, - ignore_index: Optional[int], - average: str, - mdmc_average: Optional[str], -) -> Tensor: - """Computes f_beta metric from stat scores: true positives, false positives, true negatives, false negatives. - - Args: - tp: True positives - fp: False positives - tn: True negatives - fn: False negatives - beta: The parameter `beta` (which determines the weight of recall in the combined score) - ignore_index: Integer specifying a target class to ignore. 
If given, this class index does not contribute - to the returned score, regardless of reduction method - average: Defines the reduction that is applied - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter) - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) - >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> tp, fp, tn, fn = _stat_scores_update( - ... preds, - ... target, - ... reduce='micro', - ... num_classes=3, - ... ) - >>> _fbeta_compute(tp, fp, tn, fn, beta=0.5, ignore_index=None, average='micro', mdmc_average=None) - tensor(0.3333) - """ - if average == AvgMethod.MICRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - mask = tp >= 0 - precision = _safe_divide(tp[mask].sum().float(), (tp[mask] + fp[mask]).sum()) - recall = _safe_divide(tp[mask].sum().float(), (tp[mask] + fn[mask]).sum()) - else: - precision = _safe_divide(tp.float(), tp + fp) - recall = _safe_divide(tp.float(), tp + fn) - - num = (1 + beta**2) * precision * recall - denom = beta**2 * precision + recall - denom[denom == 0.0] = 1.0 # avoid division by 0 - - # if classes matter and a given class is not present in both the preds and the target, - # computing the score for this class is meaningless, thus they should be ignored - if average == AvgMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - # a class is not present if there exists no TPs, no FPs, and no FNs - meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() - if ignore_index is None: - ignore_index = meaningless_indeces - else: - ignore_index = torch.unique(torch.cat((meaningless_indeces, torch.tensor([[ignore_index]])))) - - if ignore_index is not None: - if average not in (AvgMethod.MICRO, AvgMethod.SAMPLES) and mdmc_average == MDMCAverageMethod.SAMPLEWISE: - num[..., ignore_index] = -1 - denom[..., ignore_index] = -1 - elif average not in (AvgMethod.MICRO, AvgMethod.SAMPLES): - num[ignore_index, ...] = -1 - denom[ignore_index, ...] = -1 - - if average == AvgMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - cond = (tp + fp + fn == 0) | (tp + fp + fn == -3) - num = num[~cond] - denom = denom[~cond] - - return _reduce_stat_scores( - numerator=num, - denominator=denom, - weights=None if average != AvgMethod.WEIGHTED else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) - - def fbeta_score( preds: Tensor, target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], beta: float = 1.0, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""F-Beta score. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - - Computes f_beta metric. + r"""Computes `F-score`_ metric: .. math:: F_{\beta} = (1 + \beta^2) * \frac{\text{precision} * \text{recall}} {(\beta^2 * \text{precision}) + \text{recall}} - Works with binary, multiclass, and multilabel data. - Accepts probabilities or logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label logits or probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - beta: beta coefficient - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. 
- threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of highest probability or logit score predictions considered to find the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_fbeta_score`, :func:`multiclass_fbeta_score` and :func:`multilabel_fbeta_score` for the specific + details of each argument influence and examples. - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Example: - >>> from torchmetrics.functional import fbeta_score + Legacy Example: >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> fbeta_score(preds, target, num_classes=3, beta=0.5) + >>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5) tensor(0.3333) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_fbeta_score( - preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_fbeta_score( - preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_fbeta_score( + preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_fbeta_score( + preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - allowed_average = list(AvgMethod) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - if mdmc_average is not None and MDMCAverageMethod.from_str(mdmc_average) is None: - raise ValueError(f"The `mdmc_average` has to be one of {list(MDMCAverageMethod)}, got {mdmc_average}.") - - if average in [AvgMethod.MACRO, AvgMethod.WEIGHTED, AvgMethod.NONE] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = AvgMethod.MACRO if average in [AvgMethod.WEIGHTED, AvgMethod.NONE] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - return _fbeta_compute(tp, fp, tn, fn, beta, ignore_index, average, mdmc_average) - def f1_score( preds: Tensor, target: Tensor, - beta: float = 1.0, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""F1 score. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes F1 metric. F1 metrics correspond to equally weighted average of the precision and recall scores. - - Works with binary, multiclass, and multilabel data. - Accepts probabilities or logits from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. 
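For the `fbeta_score` wrapper above, the multilabel branch dispatches to `multilabel_fbeta_score`. A minimal sketch under the task-based signature from this patch (values illustrative; outputs not asserted):

import torch
from torchmetrics.functional import fbeta_score

# Two samples, three labels; float preds are binarized per label at threshold=0.5.
preds = torch.tensor([[0.9, 0.1, 0.8],
                      [0.2, 0.7, 0.4]])
target = torch.tensor([[1, 0, 1],
                       [0, 1, 1]])

# Dispatches to multilabel_fbeta_score; beta < 1 weights precision higher than recall.
score = fbeta_score(preds, target, task="multilabel", num_labels=3, beta=0.5)
print(score)  # micro average (the default) over all label decisions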
This is the case for binary and multi-label probabilities or logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. + r"""Computes F-1 score: - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - beta: it is ignored - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of highest probability or logit score predictions considered to find the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Return: - The shape of the returned tensor depends on the ``average`` parameter + .. 
math:: + F_{1} = 2\frac{\text{precision} * \text{recall}}{(\text{precision}) + \text{recall}} - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_f1_score`, :func:`multiclass_f1_score` and :func:`multilabel_f1_score` for the specific + details of each argument influence and examples. - Example: - >>> from torchmetrics.functional import f1_score + Legacy Example: >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) - >>> f1_score(preds, target, num_classes=3) + >>> f1_score(preds, target, task="multiclass", num_classes=3) tensor(0.3333) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_f1_score( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_f1_score( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_f1_score( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_f1_score( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - return fbeta_score( - preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index a8e75f1b89d..84da16735b1 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
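For the `f1_score` wrapper above, a minimal sketch of the binary branch, assuming the task-based signature from this patch (the hand-computed value appears only in the comment and is not asserted):

import torch
from torchmetrics.functional import f1_score

# Binary case with probability predictions, binarized at threshold=0.5.
preds = torch.tensor([0.2, 0.8, 0.6, 0.4])
target = torch.tensor([0, 1, 1, 1])

# Dispatches to binary_f1_score: precision = 1.0, recall = 2/3, so F1 = 0.8 here.
score = f1_score(preds, target, task="binary")
print(score)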
-from typing import Optional, Tuple, Union +from typing import Optional import torch from torch import Tensor @@ -31,9 +31,7 @@ _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, ) -from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.prints import rank_zero_warn def _hamming_distance_reduce( @@ -388,70 +386,20 @@ def multilabel_hamming_distance( return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) -def _hamming_distance_update( - preds: Tensor, - target: Tensor, - threshold: float = 0.5, -) -> Tuple[Tensor, int]: - """Returns the number of positions where prediction equals target, and number of predictions. - - Args: - preds: Predicted tensor - target: Ground truth tensor - threshold: Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - """ - - preds, target, _ = _input_format_classification(preds, target, threshold=threshold) - - correct = (preds == target).sum() - total = preds.numel() - - return correct, total - - -def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Tensor: - """Computes the Hamming distance. - - Args: - correct: Number of positions where prediction equals target - total: Total number of predictions - - Example: - >>> target = torch.tensor([[0, 1], [1, 1]]) - >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> correct, total = _hamming_distance_update(preds, target) - >>> _hamming_distance_compute(correct, total) - tensor(0.2500) - """ - - return 1 - correct.float() / total - - def hamming_distance( preds: Tensor, target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Hamming distance. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes the average `Hamming distance`_ (also - known as Hamming loss) between targets and predictions: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss): .. math:: \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) @@ -460,51 +408,31 @@ def hamming_distance( and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that tensor. 
- This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it - treats each possible label separately - meaning that, for example, multi-class data is - treated as if it were multi-label. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_hamming_distance`, :func:`multiclass_hamming_distance` and :func:`multilabel_hamming_distance` for + the specific details of each argument influence and examples. - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - - Example: - >>> from torchmetrics.functional import hamming_distance + Legacy Example: >>> target = torch.tensor([[0, 1], [1, 1]]) >>> preds = torch.tensor([[0, 1], [0, 1]]) - >>> hamming_distance(preds, target) + >>> hamming_distance(preds, target, task="binary") tensor(0.2500) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_hamming_distance( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_hamming_distance( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_hamming_distance( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_hamming_distance( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - - correct, total = _hamming_distance_update(preds, target, threshold) - return _hamming_distance_compute(correct, total) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index e7ae53ec3af..cb7d441243d 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import torch from torch import Tensor, tensor @@ -23,10 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) -from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot -from torchmetrics.utilities.enums import DataType, EnumStr -from torchmetrics.utilities.prints import rank_zero_warn def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -243,244 +240,43 @@ def multiclass_hinge_loss( return _hinge_loss_compute(measures, total) -class MulticlassMode(EnumStr): - """Enum to represent possible multiclass modes of hinge. - - >>> "Crammer-Singer" in list(MulticlassMode) - True - """ - - CRAMMER_SINGER = "crammer-singer" - ONE_VS_ALL = "one-vs-all" - - -def _check_shape_and_type_consistency_hinge( - preds: Tensor, - target: Tensor, -) -> DataType: - """Checks shape and type of ``preds`` and ``target`` and returns mode of the input tensors. 
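Returning to the `hamming_distance` wrapper above: the multiclass branch dispatches to `multiclass_hamming_distance`, and the wrapper default reduction changes to micro averaging in this patch. A minimal sketch (values illustrative; outputs not asserted):

import torch
from torchmetrics.functional import hamming_distance

# Multiclass case with integer predictions; dispatches to multiclass_hamming_distance.
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])

# Micro averaging is the new wrapper default.
dist = hamming_distance(preds, target, task="multiclass", num_classes=3)

# average=None returns one score per class instead of a single scalar.
dist_per_class = hamming_distance(preds, target, task="multiclass", num_classes=3, average=None)
print(dist, dist_per_class)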
- - Args: - preds: Predicted tensor - target: Ground truth tensor - - Raises: - `ValueError`: if ``target`` is not one dimensional - `ValueError`: if ``preds`` and ``target`` do not have the same shape in the first dimension - `ValueError`: if ``preds`` is neither one nor two-dimensional - """ - - if target.ndim > 1: - raise ValueError( - f"The `target` should be one dimensional, got `target` with shape={target.shape}.", - ) - - if preds.ndim == 1: - if preds.shape != target.shape: - raise ValueError( - "The `preds` and `target` should have the same shape,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - mode = DataType.BINARY - elif preds.ndim == 2: - if preds.shape[0] != target.shape[0]: - raise ValueError( - "The `preds` and `target` should have the same shape in the first dimension,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - mode = DataType.MULTICLASS - else: - raise ValueError(f"The `preds` should be one or two dimensional, got `preds` with shape={preds.shape}.") - return mode - - -def _hinge_update( - preds: Tensor, - target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Union[str, MulticlassMode]] = None, -) -> Tuple[Tensor, Tensor]: - """Updates and returns sum over Hinge loss scores for each observation and the total number of observations. - - Args: - preds: Predicted tensor - target: Ground truth tensor - squared: If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. - multiclass_mode: - Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default), - ``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss. - ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion. - """ - preds, target = _input_squeeze(preds, target) - - mode = _check_shape_and_type_consistency_hinge(preds, target) - - if mode == DataType.MULTICLASS: - target = to_onehot(target, max(2, preds.shape[1])).bool() - - if mode == DataType.MULTICLASS and (multiclass_mode is None or multiclass_mode == MulticlassMode.CRAMMER_SINGER): - margin = preds[target] - margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0] - elif mode == DataType.BINARY or multiclass_mode == MulticlassMode.ONE_VS_ALL: - target = target.bool() - margin = torch.zeros_like(preds) - margin[target] = preds[target] - margin[~target] = -preds[~target] - else: - raise ValueError( - "The `multiclass_mode` should be either None / 'crammer-singer' / MulticlassMode.CRAMMER_SINGER" - "(default) or 'one-vs-all' / MulticlassMode.ONE_VS_ALL," - f" got {multiclass_mode}." - ) - - measures = 1 - margin - measures = torch.clamp(measures, 0) - - if squared: - measures = measures.pow(2) - - total = tensor(target.shape[0], device=target.device) - return measures.sum(dim=0), total - - -def _hinge_compute(measure: Tensor, total: Tensor) -> Tensor: - """Computes mean Hinge loss. 
- - Args: - measure: Sum over hinge losses for each observation - total: Number of observations - - Example: - >>> # binary case - >>> target = torch.tensor([0, 1, 1]) - >>> preds = torch.tensor([-2.2, 2.4, 0.1]) - >>> measure, total = _hinge_update(preds, target) - >>> _hinge_compute(measure, total) - tensor(0.3000) - - >>> # multiclass case - >>> target = torch.tensor([0, 1, 2]) - >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> measure, total = _hinge_update(preds, target) - >>> _hinge_compute(measure, total) - tensor(2.9000) - - >>> # multiclass one-vs-all mode case - >>> target = torch.tensor([0, 1, 2]) - >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> measure, total = _hinge_update(preds, target, multiclass_mode="one-vs-all") - >>> _hinge_compute(measure, total) - tensor([2.2333, 1.5000, 1.2333]) - """ - - return measure / total - - def hinge_loss( preds: Tensor, target: Tensor, - squared: bool = False, - multiclass_mode: Optional[Literal["crammer-singer", "one-vs-all"]] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass"], num_classes: Optional[int] = None, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Hinge loss. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). - Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'`` or ``'multiclass'``. See the documentation of + :func:`binary_hinge_loss` and :func:`multiclass_hinge_loss` for the specific details of + each argument influence and examples. - In the binary case it is defined as: - - .. math:: - \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) - - Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. - - In the multi-class case, when ``multiclass_mode=None`` (default), ``multiclass_mode=MulticlassMode.CRAMMER_SINGER`` - or ``multiclass_mode="crammer-singer"``, this metric will compute the multi-class hinge loss defined by Crammer and - Singer as: - - .. math:: - \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) - - Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), - and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. - - In the multi-class case when ``multiclass_mode=MulticlassMode.ONE_VS_ALL`` or ``multiclass_mode='one-vs-all'``, this - metric will use a one-vs-all approach to compute the hinge loss, giving a vector of C outputs where each entry pits - that class against all remaining classes. 
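# A minimal sketch of the new task-based hinge_loss entrypoint in this hunk, assuming the
# refactored hinge.py from this patch is importable. With task="binary" the wrapper simply
# dispatches to binary_hinge_loss, so both calls should agree (tensor(0.9000) per the
# Legacy Example in the new docstring).
import torch
from torchmetrics.functional.classification.hinge import binary_hinge_loss, hinge_loss

target = torch.tensor([0, 1, 1])
preds = torch.tensor([0.5, 0.7, 0.1])
assert torch.allclose(hinge_loss(preds, target, task="binary"), binary_hinge_loss(preds, target))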
- - This metric can optionally output the mean of the squared hinge loss by setting ``squared=True`` - - Only accepts inputs with preds shape of (N) (binary) or (N, C) (multi-class) and target shape of (N). - - Args: - preds: Predictions from model (as float outputs from decision function). - target: Ground truth labels. - squared: - If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss (default). - multiclass_mode: - Which approach to use for multi-class inputs (has no effect in the binary case). ``None`` (default), - ``MulticlassMode.CRAMMER_SINGER`` or ``"crammer-singer"``, uses the Crammer Singer multi-class hinge loss. - ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"`` computes the hinge loss in a one-vs-all fashion. - - Raises: - ValueError: - If preds shape is not of size (N) or (N, C). - ValueError: - If target shape is not of size (N). - ValueError: - If ``multiclass_mode`` is not: None, ``MulticlassMode.CRAMMER_SINGER``, ``"crammer-singer"``, - ``MulticlassMode.ONE_VS_ALL`` or ``"one-vs-all"``. - - Example (binary case): + Legacy Example: >>> import torch - >>> from torchmetrics.functional import hinge_loss >>> target = torch.tensor([0, 1, 1]) - >>> preds = torch.tensor([-2.2, 2.4, 0.1]) - >>> hinge_loss(preds, target) - tensor(0.3000) + >>> preds = torch.tensor([0.5, 0.7, 0.1]) + >>> hinge_loss(preds, target, task="binary") + tensor(0.9000) - Example (default / multiclass case): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge_loss(preds, target) - tensor(2.9000) + >>> hinge_loss(preds, target, task="multiclass", num_classes=3) + tensor(1.5551) - Example (multiclass example, one vs all mode): >>> target = torch.tensor([0, 1, 2]) >>> preds = torch.tensor([[-1.0, 0.9, 0.2], [0.5, -1.1, 0.8], [2.2, -0.5, 0.3]]) - >>> hinge_loss(preds, target, multiclass_mode="one-vs-all") - tensor([2.2333, 1.5000, 1.2333]) + >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all") + tensor([1.3743, 1.1945, 1.2359]) """ - if task is not None: - if task == "binary": - return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert multiclass_mode is not None - return multiclass_hinge_loss( - preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - measure, total = _hinge_update(preds, target, squared=squared, multiclass_mode=multiclass_mode) - return _hinge_compute(measure, total) + if task == "binary": + return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) + raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 99c6bf355a7..d9c70be4169 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -22,7 +22,6 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, - _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, @@ -33,7 +32,6 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.prints import rank_zero_warn def _jaccard_index_reduce( @@ -298,178 +296,44 @@ def multilabel_jaccard_index( return _jaccard_index_reduce(confmat, average=average) -def _jaccard_from_confmat( - confmat: Tensor, - num_classes: int, - average: Optional[str] = "macro", - ignore_index: Optional[int] = None, - absent_score: float = 0.0, -) -> Tensor: - """Computes the intersection over union from confusion matrix. - - Args: - confmat: Confusion matrix without normalization - num_classes: Number of classes for a given prediction and target tensor - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. Note that if a given class doesn't occur in the - `preds` or `target`, the value for the class will be ``nan``. - - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. - absent_score: score to use for an individual class, if no instances of the class index were present in `pred` - AND no instances of the class index were present in `target`. - """ - allowed_average = ["micro", "macro", "weighted", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - # Remove the ignored class index from the scores.
- if ignore_index is not None and 0 <= ignore_index < num_classes: - confmat[ignore_index] = 0.0 - - if average == "none" or average is None: - intersection = torch.diag(confmat) - union = confmat.sum(0) + confmat.sum(1) - intersection - - # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. - scores = intersection.float() / union.float() - scores = scores.where(union != 0, torch.tensor(absent_score, dtype=scores.dtype, device=scores.device)) - - if ignore_index is not None and 0 <= ignore_index < num_classes: - scores = torch.cat( - [ - scores[:ignore_index], - scores[ignore_index + 1 :], - ] - ) - return scores - - if average == "macro": - scores = _jaccard_from_confmat( - confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score - ) - return torch.mean(scores) - - if average == "micro": - intersection = torch.sum(torch.diag(confmat)) - union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat)) - return intersection.float() / union.float() - - weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float() - scores = _jaccard_from_confmat( - confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score - ) - return torch.sum(weights * scores) - - def jaccard_index( preds: Tensor, target: Tensor, - num_classes: int, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", - ignore_index: Optional[int] = None, - absent_score: float = 0.0, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Jaccard index. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Jaccard index`_ + r"""Calculates the Jaccard index. The `Jaccard index`_ (also known as the intersection over + union or Jaccard similarity coefficient) is a statistic that can be used to determine the similarity and + diversity of a sample set. It is defined as the size of the intersection divided by the union of the sample + sets: .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - Where: :math:`A` and :math:`B` are both tensors of the same size, - containing integer class values. They may be subject to conversion from - input data (see description below). - - Note that it is different from box IoU. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``.
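# A small sketch of the task-based jaccard_index wrapper described in the docstring above,
# assuming the refactored jaccard.py from this patch is importable. With task="binary" it
# dispatches straight to binary_jaccard_index; float predictions are thresholded at 0.5 by default.
import torch
from torchmetrics.functional.classification.jaccard import binary_jaccard_index, jaccard_index

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0.2, 0.7, 0.8, 0.3])
assert torch.allclose(
    jaccard_index(preds, target, task="binary"),
    binary_jaccard_index(preds, target),
)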
See the documentation of + :func:`binary_jaccard_index`, :func:`multiclass_jaccard_index` and :func:`multilabel_jaccard_index` for + the specific details of each argument influence and examples. - If pred has an extra dimension as in the case of multi-class scores we - perform an argmax on ``dim=1``. - - Args: - preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]`` - target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]`` - num_classes: Specify the number of classes - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'macro'`` [default]: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'micro'``: Calculate the metric globally, across all samples and classes. - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. Note that if a given class doesn't occur in the - `preds` or `target`, the value for the class will be ``nan``. - - ignore_index: optional int specifying a target class to ignore. If given, - this class index does not contribute to the returned score, regardless - of reduction method. Has no effect if given an int that is not in the - range ``[0, num_classes-1]``, where num_classes is either given or derived - from pred and target. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of - the class index were present in ``preds`` AND no instances of the class - index were present in ``target``. For example, if we have 3 classes, - [0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be - assigned the `absent_score`. - threshold: Threshold value for binary or multi-label probabilities. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Example: - >>> from torchmetrics.functional import jaccard_index + Legacy Example: >>> target = torch.randint(0, 2, (10, 25, 25)) >>> pred = torch.tensor(target) >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> jaccard_index(pred, target, num_classes=2) + >>> jaccard_index(pred, target, task="multiclass", num_classes=2) tensor(0.9660) """ - if task is not None: - if task == "binary": - return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - confmat = _confusion_matrix_update(preds, target, num_classes, threshold) - return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score) + if task == "binary": + return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 07e02a6a4c5..d70510de112 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -22,7 +22,6 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, - _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, @@ -32,7 +31,6 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) -from torchmetrics.utilities.prints import rank_zero_warn def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: @@ -232,103 +230,38 @@ def multilabel_matthews_corrcoef( return _matthews_corrcoef_reduce(confmat) -_matthews_corrcoef_update = _confusion_matrix_update - - -def _matthews_corrcoef_compute(confmat: Tensor) -> Tensor: - """Computes Matthews correlation coefficient. - - Args: - confmat: Confusion matrix - - Example: - >>> target = torch.tensor([1, 1, 0, 0]) - >>> preds = torch.tensor([0, 1, 0, 0]) - >>> confmat = _matthews_corrcoef_update(preds, target, num_classes=2) - >>> _matthews_corrcoef_compute(confmat) - tensor(0.5774) - """ - - tk = confmat.sum(dim=1).float() - pk = confmat.sum(dim=0).float() - c = torch.trace(confmat).float() - s = confmat.sum().float() - - cov_ytyp = c * s - sum(tk * pk) - cov_ypyp = s**2 - sum(pk * pk) - cov_ytyt = s**2 - sum(tk * tk) - - if cov_ypyp * cov_ytyt == 0: - return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) - else: - return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp) - - def matthews_corrcoef( preds: Tensor, target: Tensor, - num_classes: int, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Matthews correlation coefficient. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11.
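# A brief sketch of the task-based matthews_corrcoef wrapper defined above, assuming the
# refactored matthews_corrcoef.py from this patch is importable. The inputs mirror the
# Legacy Example in the new docstring (expected value tensor(0.5774)); the wrapper with
# task="multiclass" dispatches to multiclass_matthews_corrcoef.
import torch
from torchmetrics.functional.classification.matthews_corrcoef import (
    matthews_corrcoef,
    multiclass_matthews_corrcoef,
)

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])
assert torch.allclose(
    matthews_corrcoef(preds, target, task="multiclass", num_classes=2),
    multiclass_matthews_corrcoef(preds, target, num_classes=2),
)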
From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Calculates `Matthews correlation coefficient`_ that measures - the general correlation or quality of a classification. In the binary case it - is defined as: + r"""Calculates `Matthews correlation coefficient`_ . This metric measures the general correlation or quality of + a classification. - .. math:: - MCC = \frac{TP*TN - FP*FN}{\sqrt{(TP+FP)*(TP+FN)*(TN+FP)*(TN+FN)}} + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_matthews_corrcoef`, :func:`multiclass_matthews_corrcoef` and :func:`multilabel_matthews_corrcoef` for + the specific details of each argument influence and examples. - where TP, TN, FP and FN are respectively the true postitives, true negatives, - false positives and false negatives. Also works in the case of multi-label or - multi-class input. - - Args: - preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or - ``(N, C, ...)`` where C is the number of classes, tensor with labels/probabilities - target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels - num_classes: Number of classes in the dataset. - threshold: - Threshold value for binary or multi-label probabilities. - - Example: - >>> from torchmetrics.functional import matthews_corrcoef + Legacy Example: >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) - >>> matthews_corrcoef(preds, target, num_classes=2) + >>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2) tensor(0.5774) """ - if task is not None: - if task == "binary": - return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - confmat = _matthews_corrcoef_update(preds, target, num_classes, threshold) - return _matthews_corrcoef_compute(confmat) + if task == "binary": + return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index 3f6fc6f3e0a..1c2f20e0e22 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional import torch from torch import Tensor @@ -30,12 +30,8 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, - _reduce_stat_scores, - _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _precision_recall_reduce( @@ -656,601 +652,105 @@ def multilabel_recall( return _precision_recall_reduce("recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average) -def _precision_compute( - tp: Tensor, - fp: Tensor, - fn: Tensor, - average: Optional[str], - mdmc_average: Optional[str], -) -> Tensor: - """Computes precision from the stat scores: true positives, false positives, true negatives, false negatives. 
- - Args: - tp: True positives - fp: False positives - fn: False negatives - average: Defines the reduction that is applied - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter) - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update( preds, target, reduce='macro', num_classes=3) - >>> _precision_compute(tp, fp, fn, average='macro', mdmc_average=None) - tensor(0.1667) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') - >>> _precision_compute(tp, fp, fn, average='micro', mdmc_average=None) - tensor(0.2500) - """ - - numerator = tp.clone() - denominator = tp + fp - - if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - cond = tp + fp + fn == 0 - numerator = numerator[~cond] - denominator = denominator[~cond] - - if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - # a class is not present if there exists no TPs, no FPs, and no FNs - meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() - numerator[meaningless_indeces, ...] = -1 - denominator[meaningless_indeces, ...] = -1 - - return _reduce_stat_scores( - numerator=numerator, - denominator=denominator, - weights=None if average != "weighted" else tp + fn, - average=average, - mdmc_average=mdmc_average, - ) - - def precision( preds: Tensor, target: Tensor, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Precision. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Precision`_ + r"""Computes `Precision`_: .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Precision@K. - - The reduction method (how the precision scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - average: - Defines the reduction that is applied. 
Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. + false positives respecitively. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_precision`, :func:`multiclass_precision` and :func:`multilabel_precision` for the specific details of + each argument influence and examples. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of highest probability or logit score predictions considered to find the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. 
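# A short sketch of the task-based precision wrapper in this hunk, assuming the refactored
# precision_recall.py from this patch is importable. With task="multiclass" it dispatches to
# multiclass_precision; the inputs mirror the Legacy Example in the new docstring
# (macro average -> tensor(0.1667)).
import torch
from torchmetrics.functional.classification.precision_recall import multiclass_precision, precision

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
assert torch.allclose(
    precision(preds, target, task="multiclass", average="macro", num_classes=3),
    multiclass_precision(preds, target, num_classes=3, average="macro"),
)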
- - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or ``None`` - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - - Example: - >>> from torchmetrics.functional import precision + Legacy Example: >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision(preds, target, average='macro', num_classes=3) + >>> precision(preds, target, task="multiclass", average='macro', num_classes=3) tensor(0.1667) - >>> precision(preds, target, average='micro') + >>> precision(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.2500) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_precision( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_precision( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_precision( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_precision( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - allowed_average = ["micro", "macro", "weighted", "samples", "none", None] - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = [None, "samplewise", "global"] - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, _, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _precision_compute(tp, fp, fn, average, mdmc_average) - - -def _recall_compute( - tp: Tensor, - fp: Tensor, - fn: Tensor, - average: Optional[str], - mdmc_average: Optional[str], -) -> Tensor: - """Computes precision from the stat scores: true positives, false positives, true negatives, false negatives. - - Args: - tp: True positives - fp: False positives - fn: False negatives - average: Defines the reduction that is applied - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter) - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) - >>> _recall_compute(tp, fp, fn, average='macro', mdmc_average=None) - tensor(0.3333) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') - >>> _recall_compute(tp, fp, fn, average='micro', mdmc_average=None) - tensor(0.2500) - """ - numerator = tp.clone() - denominator = tp + fn - - if average == AverageMethod.MACRO and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - cond = tp + fp + fn == 0 - numerator = numerator[~cond] - denominator = denominator[~cond] - - if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - # a class is not present if there exists no TPs, no FPs, and no FNs - meaningless_indeces = ((tp | fn | fp) == 0).nonzero().cpu() - numerator[meaningless_indeces, ...] = -1 - denominator[meaningless_indeces, ...] 
= -1 - - return _reduce_stat_scores( - numerator=numerator, - denominator=denominator, - weights=None if average != AverageMethod.WEIGHTED else tp + fn, - average=average, - mdmc_average=mdmc_average, + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) def recall( preds: Tensor, target: Tensor, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Recall. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Recall`_ + r"""Computes `Recall`_: .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and - false negatives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Recall@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. 
- - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. + false negatives respecitively. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_recall`, :func:`multiclass_recall` and :func:`multilabel_recall` for the specific details of + each argument influence and examples. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of the highest probability or logit score predictions considered finding the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or ``None`` - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. 
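# A short sketch of the task-based recall wrapper in this hunk, assuming the refactored
# precision_recall.py from this patch is importable; the inputs mirror the Legacy Example
# in the new docstring (macro average -> tensor(0.3333)). task="multiclass" dispatches
# to multiclass_recall.
import torch
from torchmetrics.functional.classification.precision_recall import multiclass_recall, recall

preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
assert torch.allclose(
    recall(preds, target, task="multiclass", average="macro", num_classes=3),
    multiclass_recall(preds, target, num_classes=3, average="macro"),
)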
- - Example: - >>> from torchmetrics.functional import recall + Legacy Example: >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> recall(preds, target, average='macro', num_classes=3) + >>> recall(preds, target, task="multiclass", average='macro', num_classes=3) tensor(0.3333) - >>> recall(preds, target, average='micro') + >>> recall(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.2500) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_recall( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_recall( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_recall( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_recall( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - allowed_average = ("micro", "macro", "weighted", "samples", "none", None) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = (None, "samplewise", "global") - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, _, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, - ) - - return _recall_compute(tp, fp, fn, average, mdmc_average) - - -def precision_recall( - preds: Tensor, - target: Tensor, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, - threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, -) -> Tuple[Tensor, Tensor]: - r"""Computes `Precision`_ - - .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} - - .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} - - Where :math:`\text{TP}`m :math:`\text{FN}` and :math:`\text{FP}` represent the number - of true positives, false negatives and false positives respecitively. With the use of - ``top_k`` parameter, this metric can generalize to Recall@K and Precision@K. - - The reduction method (how the recall scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tp + fn``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. 
note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. - - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. - - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of highest probability or logit score predictions considered to find the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Return: - The function returns a tuple with two elements: precision and recall. Their shape - depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, they are a single element tensor - - If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for - the number of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or ``None`` - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. 
- - Example: - >>> from torchmetrics.functional import precision_recall - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> precision_recall(preds, target, average='macro', num_classes=3) - (tensor(0.1667), tensor(0.3333)) - >>> precision_recall(preds, target, average='micro') - (tensor(0.2500), tensor(0.2500)) - """ - allowed_average = ("micro", "macro", "weighted", "samples", "none", None) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = (None, "samplewise", "global") - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, _, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - precision_ = _precision_compute(tp, fp, fn, average, mdmc_average) - recall_ = _recall_compute(tp, fp, fn, average, mdmc_average) - - return precision_, recall_ diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index a505898f040..e69d46275e9 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -19,7 +19,6 @@ from torch.nn import functional as F from typing_extensions import Literal -from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount @@ -769,278 +768,41 @@ def multilabel_precision_recall_curve( return _multilabel_precision_recall_curve_compute(state, num_labels, thresholds, ignore_index) -def _precision_recall_curve_update( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Updates and returns variables required to compute the precision-recall pairs for different thresholds. - - Args: - preds: Predicted tensor. - target: Ground truth tensor. - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problems is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1]. 
- """ - - if len(preds.shape) == len(target.shape): - if pos_label is None: - pos_label = 1 - if num_classes is not None and num_classes != 1: - # multilabel problem - if num_classes != preds.shape[1]: - raise ValueError( - f"Argument `num_classes` was set to {num_classes} in" - f" metric `precision_recall_curve` but detected {preds.shape[1]}" - " number of classes from predictions" - ) - preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) - target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) - else: - # binary problem - preds = preds.flatten() - target = target.flatten() - num_classes = 1 - - # multi class problem - elif len(preds.shape) == len(target.shape) + 1: - if pos_label is not None: - rank_zero_warn( - "Argument `pos_label` should be `None` when running" - f" multiclass precision recall curve. Got {pos_label}" - ) - if num_classes != preds.shape[1]: - raise ValueError( - f"Argument `num_classes` was set to {num_classes} in" - f" metric `precision_recall_curve` but detected {preds.shape[1]}" - " number of classes from predictions" - ) - preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1) - target = target.flatten() - - else: - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - - return preds, target, num_classes, pos_label - - -def _precision_recall_curve_compute_single_class( - preds: Tensor, - target: Tensor, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Tuple[Tensor, Tensor, Tensor]: - """Computes precision-recall pairs for single class inputs. - - Args: - preds: Predicted tensor. - target: Ground truth tensor. - pos_label: integer determining the positive class. - sample_weights: sample weights for each data point. - """ - - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - precision = tps / (tps + fps) - recall = tps / tps[-1] - - # stop when full recall attained and reverse the outputs so recall is decreasing - last_ind = torch.where(tps == tps[-1])[0][0] - sl = slice(0, last_ind.item() + 1) - - # need to call reversed explicitly, since including that to slice would - # introduce negative strides that are not yet supported in pytorch - precision = torch.cat([reversed(precision[sl]), torch.ones(1, dtype=precision.dtype, device=precision.device)]) - - recall = torch.cat([reversed(recall[sl]), torch.zeros(1, dtype=recall.dtype, device=recall.device)]) - - thresholds = reversed(thresholds[sl]).detach().clone() # type: ignore - - return precision, recall, thresholds - - -def _precision_recall_curve_compute_multi_class( - preds: Tensor, - target: Tensor, - num_classes: int, - sample_weights: Optional[Sequence] = None, -) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: - """Computes precision-recall pairs for multiclass inputs. - - Args: - preds: Predicted tensor. - target: Ground truth tensor. - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - sample_weights: sample weights for each data point. 
- """ - - # Recursively call per class - precision, recall, thresholds = [], [], [] - for cls in range(num_classes): - preds_cls = preds[:, cls] - - prc_args = dict( - preds=preds_cls, - target=target, - num_classes=1, - pos_label=cls, - sample_weights=sample_weights, - ) - if target.ndim > 1: - prc_args.update( - dict( - target=target[:, cls], - pos_label=1, - ) - ) - res = precision_recall_curve(**prc_args) - precision.append(res[0]) - recall.append(res[1]) - thresholds.append(res[2]) - - return precision, recall, thresholds - - -def _precision_recall_curve_compute( - preds: Tensor, - target: Tensor, - num_classes: int, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes precision-recall pairs based on the number of classes. - - Args: - preds: Predicted tensor. - target: Ground truth tensor. - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problems is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range ``[0,num_classes-1]``. - sample_weights: sample weights for each data point. - - Example: - >>> # binary case - >>> preds = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 0]) - >>> pos_label = 1 - >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, pos_label=pos_label) - >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes, pos_label) - >>> precision - tensor([0.6667, 0.5000, 0.0000, 1.0000]) - >>> recall - tensor([1.0000, 0.5000, 0.0000, 0.0000]) - >>> thresholds - tensor([1, 2, 3]) - - >>> # multiclass case - >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05, 0.05], - ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 5 - >>> preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes) - >>> precision, recall, thresholds = _precision_recall_curve_compute(preds, target, num_classes) - >>> precision - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), - tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] - >>> recall - [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] - >>> thresholds - [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] - """ - - with torch.no_grad(): - if num_classes == 1: - if pos_label is None: - pos_label = 1 - return _precision_recall_curve_compute_single_class(preds, target, pos_label, sample_weights) - return _precision_recall_curve_compute_multi_class(preds, target, num_classes, sample_weights) - - def precision_recall_curve( preds: Tensor, target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, num_labels: Optional[int] = None, - thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Precision-recall. + r"""Computes the precision-recall curve. The curve consist of multiple pairs of precision and recall values + evaluated at different thresholds, such that the tradeoff between the two values can been seen. - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_precision_recall_curve`, :func:`multiclass_precision_recall_curve` and + :func:`multilabel_precision_recall_curve` for the specific details of each argument influence and examples. - Computes precision-recall pairs for different thresholds. - - Args: - preds: predictions from model (probabilities). - target: ground truth labels. - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]``. - sample_weights: sample weights for each data point. - - Returns: - 3-element tuple containing - - precision: - tensor where element ``i`` is the precision of predictions with - ``score >= thresholds[i]`` and the last element is 1. 
- If multiclass, this is a list of such tensors, one for each class. - recall: - tensor where element ``i`` is the recall of predictions with - ``score >= thresholds[i]`` and the last element is 0. - If multiclass, this is a list of such tensors, one for each class. - thresholds: - Thresholds used for computing precision/recall scores. - - Raises: - ValueError: - If ``preds`` and ``target`` don't have the same number of dimensions, - or one additional dimension for ``preds``. - ValueError: - If the number of classes deduced from ``preds`` is not the same as the ``num_classes`` provided. - - Example (binary case): - >>> from torchmetrics.functional import precision_recall_curve - >>> pred = torch.tensor([0, 1, 2, 3]) + Legacy Example: + >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0]) >>> target = torch.tensor([0, 1, 1, 0]) - >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, task='binary') >>> precision tensor([0.6667, 0.5000, 0.0000, 1.0000]) >>> recall tensor([1.0000, 0.5000, 0.0000, 0.0000]) >>> thresholds - tensor([1, 2, 3]) + tensor([0.7311, 0.8808, 0.9526]) - Example (multiclass case): >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5) + >>> precision, recall, thresholds = precision_recall_curve(pred, target, task='multiclass', num_classes=5) >>> precision [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] @@ -1049,28 +811,14 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ - if task is not None: - if task == "binary": - return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_precision_recall_curve( - preds, target, num_classes, thresholds, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) - return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) + if task == "binary": + return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index 967fdb1dc03..f42176e4492 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _binary_precision_recall_curve_arg_validation, @@ -86,9 +87,9 @@ def binary_recall_at_fixed_precision( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tuple[Tensor, Tensor]: - r"""Computes the higest possible recall value given the minimum precision thresholds provided. This is done by - first calculating the precision-recall curve for different thresholds and the find the recall for a given - precision level. + r"""Computes the highest possible recall value given the minimum precision thresholds provided for binary tasks. + This is done by first calculating the precision-recall curve for different thresholds and then finding the recall + for a given precision level. Accepts the following input tensors: @@ -185,9 +186,9 @@ def multiclass_recall_at_fixed_precision( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tuple[Tensor, Tensor]: - r"""Computes the higest possible recall value given the minimum precision thresholds provided. This is done by - first calculating the precision-recall curve for different thresholds and the find the recall for a given - precision level. + r"""Computes the highest possible recall value given the minimum precision thresholds provided for multiclass + tasks. This is done by first calculating the precision-recall curve for different thresholds and then finding the + recall for a given precision level. Accepts the following input tensors: @@ -293,9 +294,9 @@ def multilabel_recall_at_fixed_precision( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tuple[Tensor, Tensor]: - r"""Computes the higest possible recall value given the minimum precision thresholds provided. This is done by - first calculating the precision-recall curve for different thresholds and the find the recall for a given - precision level. + r"""Computes the highest possible recall value given the minimum precision thresholds provided for multilabel + tasks.
This is done by first calculating the precision-recall curve for different thresholds and then finding the + recall for a given precision level. Accepts the following input tensors: @@ -361,3 +362,40 @@ def multilabel_recall_at_fixed_precision( ) state = _multilabel_precision_recall_curve_update(preds, target, num_labels, thresholds) return _multilabel_recall_at_fixed_precision_arg_compute(state, num_labels, thresholds, ignore_index, min_precision) + + +def recall_at_fixed_precision( + preds: Tensor, + target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + min_precision: float, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + r"""Computes the highest possible recall value given the minimum precision thresholds provided. This is done by + first calculating the precision-recall curve for different thresholds and then finding the recall for a given + precision level. + + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_recall_at_fixed_precision`, :func:`multiclass_recall_at_fixed_precision` and + :func:`multilabel_recall_at_fixed_precision` for the specific details of each argument influence and examples. + """ + if task == "binary": + return binary_recall_at_fixed_precision(preds, target, min_precision, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_recall_at_fixed_precision( + preds, target, num_classes, min_precision, thresholds, ignore_index, validate_args + ) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_recall_at_fixed_precision( + preds, target, num_labels, min_precision, thresholds, ignore_index, validate_args + ) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index c29e234a8c7..0726069e67b 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Sequence, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor @@ -31,7 +31,6 @@ _multilabel_precision_recall_curve_format, _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, - _precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.compute import _safe_divide @@ -186,9 +185,9 @@ def multiclass_roc( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Computes the Receiver Operating Characteristic (ROC) for binary tasks.
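A short usage sketch of the ``recall_at_fixed_precision`` wrapper added above (illustrative only: the example data is made up and the import assumes the new module path is exported as usual):

    import torch
    from torchmetrics.functional.classification.recall_at_fixed_precision import recall_at_fixed_precision

    preds = torch.tensor([0.1, 0.4, 0.6, 0.8])
    target = torch.tensor([0, 0, 1, 1])

    # highest recall achievable while precision stays at or above 0.5,
    # together with the threshold at which that recall is reached
    recall, threshold = recall_at_fixed_precision(preds, target, task="binary", min_precision=0.5)

    # an integer `thresholds` evaluates a fixed grid of candidate thresholds
    # instead of every unique score, keeping memory constant on large inputs
    recall, threshold = recall_at_fixed_precision(
        preds, target, task="binary", min_precision=0.5, thresholds=100
    )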
The curve consist of multiple pairs - of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that - the tradeoff between the two values can be seen. + r"""Computes the Receiver Operating Characteristic (ROC) for multiclass tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such + that the tradeoff between the two values can be seen. Accepts the following input tensors: @@ -321,9 +320,9 @@ def multilabel_roc( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Computes the Receiver Operating Characteristic (ROC) for binary tasks. The curve consist of multiple pairs - of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that - the tradeoff between the two values can be seen. + r"""Computes the Receiver Operating Characteristic (ROC) for multilabel tasks. The curve consist of multiple + pairs of true positive rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such + that the tradeoff between the two values can be seen. Accepts the following input tensors: @@ -421,264 +420,58 @@ def multilabel_roc( return _multilabel_roc_compute(state, num_labels, thresholds, ignore_index) -def _roc_update( - preds: Tensor, - target: Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, -) -> Tuple[Tensor, Tensor, int, Optional[int]]: - """Updates and returns variables required to compute the Receiver Operating Characteristic. - - Args: - preds: Predicted tensor - target: Ground truth tensor - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` - which for binary problem is translated to 1. For multiclass problems - this argument should not be set as we iteratively change it in the - range [0,num_classes-1] - """ - - return _precision_recall_curve_update(preds, target, num_classes, pos_label) - - -def _roc_compute_single_class( - preds: Tensor, - target: Tensor, - pos_label: int, - sample_weights: Optional[Sequence] = None, -) -> Tuple[Tensor, Tensor, Tensor]: - """Computes Receiver Operating Characteristic for single class inputs. Returns tensor with false positive - rates, tensor with true positive rates, tensor with thresholds used for computing false- and true-postive - rates. - - Args: - preds: Predicted tensor - target: Ground truth tensor - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - sample_weights: sample weights for each data point - """ - - fps, tps, thresholds = _binary_clf_curve( - preds=preds, target=target, sample_weights=sample_weights, pos_label=pos_label - ) - # Add an extra threshold position to make sure that the curve starts at (0, 0) - tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) - fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) - thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) - - if fps[-1] <= 0: - rank_zero_warn( - "No negative samples in targets, false positive value should be meaningless." 
- " Returning zero tensor in false positive score", - UserWarning, - ) - fpr = torch.zeros_like(thresholds) - else: - fpr = fps / fps[-1] - - if tps[-1] <= 0: - rank_zero_warn( - "No positive samples in targets, true positive value should be meaningless." - " Returning zero tensor in true positive score", - UserWarning, - ) - tpr = torch.zeros_like(thresholds) - else: - tpr = tps / tps[-1] - - return fpr, tpr, thresholds - - -def _roc_compute_multi_class( - preds: Tensor, - target: Tensor, - num_classes: int, - sample_weights: Optional[Sequence] = None, -) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: - """Computes Receiver Operating Characteristic for multi class inputs. Returns tensor with false positive rates, - tensor with true positive rates, tensor with thresholds used for computing false- and true-postive rates. - - Args: - preds: Predicted tensor - target: Ground truth tensor - num_classes: number of classes - sample_weights: sample weights for each data point - """ - - fpr, tpr, thresholds = [], [], [] - for cls in range(num_classes): - if preds.shape == target.shape: - target_cls = target[:, cls] - pos_label = 1 - else: - target_cls = target - pos_label = cls - res = roc( - preds=preds[:, cls], - target=target_cls, - num_classes=1, - pos_label=pos_label, - sample_weights=sample_weights, - ) - fpr.append(res[0]) - tpr.append(res[1]) - thresholds.append(res[2]) - - return fpr, tpr, thresholds - - -def _roc_compute( - preds: Tensor, - target: Tensor, - num_classes: int, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, -) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes Receiver Operating Characteristic based on the number of classes. - - Args: - preds: Predicted tensor - target: Ground truth tensor - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. For multiclass problems this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - sample_weights: sample weights for each data point - - Example: - >>> # binary case - >>> preds = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 1, 1]) - >>> pos_label = 1 - >>> preds, target, num_classes, pos_label = _roc_update(preds, target, pos_label=pos_label) - >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes, pos_label) - >>> fpr - tensor([0., 0., 0., 0., 1.]) - >>> tpr - tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) - >>> thresholds - tensor([4, 3, 2, 1, 0]) - - >>> # multiclass case - >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05], - ... [0.05, 0.75, 0.05, 0.05], - ... [0.05, 0.05, 0.75, 0.05], - ... 
[0.05, 0.05, 0.05, 0.75]]) - >>> target = torch.tensor([0, 1, 3, 2]) - >>> num_classes = 4 - >>> preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes) - >>> fpr, tpr, thresholds = _roc_compute(preds, target, num_classes) - >>> fpr - [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] - >>> tpr - [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] - >>> thresholds - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] - """ - - with torch.no_grad(): - if num_classes == 1 and preds.ndim == 1: # binary - if pos_label is None: - pos_label = 1 - return _roc_compute_single_class(preds, target, pos_label, sample_weights) - return _roc_compute_multi_class(preds, target, num_classes, sample_weights) - - def roc( preds: Tensor, target: Tensor, - num_classes: Optional[int] = None, - pos_label: Optional[int] = None, - sample_weights: Optional[Sequence] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + task: Literal["binary", "multiclass", "multilabel"], thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - r"""Receiver Operating Characteristic. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. + r"""Computes the Receiver Operating Characteristic (ROC). The curve consist of multiple pairs of true positive + rate (TPR) and false positive rate (FPR) values evaluated at different thresholds, such that the tradeoff + between the two values can be seen. - Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel - input. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_roc`, :func:`multiclass_roc` and :func:`multilabel_roc` for the specific details of each argument + influence and examples. - .. note:: - If either the positive class or negative class is completly missing in the target tensor, - the roc values are not well-defined in this case and a tensor of zeros will be returned (either fpr - or tpr depending on what class is missing) together with a warning. - - Args: - preds: predictions from model (logits or probabilities) - target: ground truth values - num_classes: integer with number of classes for multi-label and multiclass problems. - Should be set to ``None`` for binary problems. - pos_label: integer determining the positive class. Default is ``None`` which for binary problem is translated - to 1. 
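Complementing the legacy ``roc`` examples that follow, a minimal sketch of calling the task-specific ``binary_roc`` directly with binned thresholds (made-up data; the import path is assumed to be exported as elsewhere in this refactor):

    import torch
    from torchmetrics.functional.classification import binary_roc

    preds = torch.tensor([0.2, 0.7, 0.4, 0.9])
    target = torch.tensor([0, 1, 0, 1])

    # exact curve: thresholds derived from the unique prediction values
    fpr, tpr, thresholds = binary_roc(preds, target)

    # binned curve: a fixed number of evenly spaced thresholds in [0, 1]
    fpr, tpr, thresholds = binary_roc(preds, target, thresholds=20)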
For multiclass problems this argument should not be set as we iteratively change it in the - range ``[0, num_classes-1]`` - sample_weights: sample weights for each data point - - Returns: - 3-element tuple containing - - fpr: tensor with false positive rates. - If multiclass or multilabel, this is a list of such tensors, one for each class/label. - tpr: tensor with true positive rates. - If multiclass or multilabel, this is a list of such tensors, one for each class/label. - thresholds: tensor with thresholds used for computing false- and true postive rates - If multiclass or multilabel, this is a list of such tensors, one for each class/label. - - Example (binary case): - >>> from torchmetrics.functional import roc - >>> pred = torch.tensor([0, 1, 2, 3]) + Legacy Example: + >>> pred = torch.tensor([0.0, 1.0, 2.0, 3.0]) >>> target = torch.tensor([0, 1, 1, 1]) - >>> fpr, tpr, thresholds = roc(pred, target, pos_label=1) + >>> fpr, tpr, thresholds = roc(pred, target, task='binary') >>> fpr tensor([0., 0., 0., 0., 1.]) >>> tpr tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) >>> thresholds - tensor([4, 3, 2, 1, 0]) + tensor([1.0000, 0.9526, 0.8808, 0.7311, 0.5000]) - Example (multiclass case): - >>> from torchmetrics.functional import roc >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05], ... [0.05, 0.05, 0.05, 0.75]]) >>> target = torch.tensor([0, 1, 3, 2]) - >>> fpr, tpr, thresholds = roc(pred, target, num_classes=4) + >>> fpr, tpr, thresholds = roc(pred, target, task='multiclass', num_classes=4) >>> fpr [tensor([0., 0., 1.]), tensor([0., 0., 1.]), tensor([0.0000, 0.3333, 1.0000]), tensor([0.0000, 0.3333, 1.0000])] >>> tpr [tensor([0., 1., 1.]), tensor([0., 1., 1.]), tensor([0., 0., 1.]), tensor([0., 0., 1.])] >>> thresholds - [tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500]), - tensor([1.7500, 0.7500, 0.0500])] + [tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500]), + tensor([1.0000, 0.7500, 0.0500])] - Example (multilabel case): - >>> from torchmetrics.functional import roc >>> pred = torch.tensor([[0.8191, 0.3680, 0.1138], ... [0.3584, 0.7576, 0.1183], ... [0.2286, 0.3468, 0.1338], ... 
[0.8603, 0.0745, 0.1837]]) >>> target = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 0], [0, 1, 1]]) - >>> fpr, tpr, thresholds = roc(pred, target, num_classes=3, pos_label=1) + >>> fpr, tpr, thresholds = roc(pred, target, task='multilabel', num_labels=3) >>> fpr [tensor([0.0000, 0.3333, 0.3333, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), @@ -686,30 +479,18 @@ def roc( >>> tpr [tensor([0., 0., 1., 1., 1.]), tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 1., 1., 1., 1.])] >>> thresholds - [tensor([1.8603, 0.8603, 0.8191, 0.3584, 0.2286]), - tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), - tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] + [tensor([1.0000, 0.8603, 0.8191, 0.3584, 0.2286]), + tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), + tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ - if task is not None: - if task == "binary": - return binary_roc(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, - ) - preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) - return _roc_compute(preds, target, num_classes, pos_label, sample_weights) + if task == "binary": + return binary_roc(preds, target, thresholds, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 93577a67c0d..b9be98e100e 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -30,12 +30,8 @@ _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, _multilabel_stat_scores_update, - _reduce_stat_scores, - _stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _specificity_reduce( @@ -358,225 +354,53 @@ def multilabel_specificity( return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) -def _specificity_compute( - tp: Tensor, - fp: Tensor, - tn: Tensor, - fn: Tensor, - 
average: Optional[str], - mdmc_average: Optional[str], -) -> Tensor: - """Computes specificity from the stat scores: true positives, false positives, true negatives, false negatives. - - Args: - tp: True positives - fp: False positives - tn: True negatives - fn: False negatives - average: Defines the reduction that is applied - mdmc_average: Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter) - - Example: - >>> from torchmetrics.functional.classification.stat_scores import _stat_scores_update - >>> preds = torch.tensor([2, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) - >>> _specificity_compute(tp, fp, tn, fn, average='macro', mdmc_average=None) - tensor(0.6111) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') - >>> _specificity_compute(tp, fp, tn, fn, average='micro', mdmc_average=None) - tensor(0.6250) - """ - - numerator = tn.clone() - denominator = tn + fp - if average == AverageMethod.NONE and mdmc_average != MDMCAverageMethod.SAMPLEWISE: - # a class is not present if there exists no TPs, no FPs, and no FNs - meaningless_indeces = torch.nonzero((tp | fn | fp) == 0).cpu() - numerator[meaningless_indeces, ...] = -1 - denominator[meaningless_indeces, ...] = -1 - return _reduce_stat_scores( - numerator=numerator, - denominator=denominator, - weights=None if average != AverageMethod.WEIGHTED else denominator, - average=average, - mdmc_average=mdmc_average, - ) - - def specificity( preds: Tensor, target: Tensor, - average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", - mdmc_average: Optional[str] = None, - ignore_index: Optional[int] = None, - num_classes: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - top_k: Optional[int] = None, - multiclass: Optional[bool] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Specificity. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - Computes `Specificity`_ + r"""Computes `Specificity`_. .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and - false positives respecitively. With the use of ``top_k`` parameter, this metric can - generalize to Specificity@K. - - The reduction method (how the specificity scores are aggregated) is controlled by the - ``average`` parameter, and additionally by the ``mdmc_average`` parameter in the - multi-dimensional multi-class case. 
- - Args: - preds: Predictions from model (probabilities, or labels) - target: Ground truth values - average: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Calculate the metric globally, across all samples and classes. - - ``'macro'``: Calculate the metric for each class separately, and average the - metrics across classes (with equal weights for each class). - - ``'weighted'``: Calculate the metric for each class separately, and average the - metrics across classes, weighting each class by its support (``tn + fp``). - - ``'none'`` or ``None``: Calculate the metric for each class separately, and return - the metric for every class. - - ``'samples'``: Calculate the metric for each sample, and average the metrics - across samples (with equal weights for each sample). - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_average``. - - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, - the value for the class will be ``nan``. - - mdmc_average: - Defines how averaging is done for multi-dimensional multi-class inputs (on top of the - ``average`` parameter). Should be one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional multi-class. - - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then averaged over samples. - The computation for each sample is done by treating the flattened extra axes ``...`` - as the ``N`` dimension within the sample, - and computing the metric for the sample based on that. + false positives respecitively. - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs - are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``average`` parameter applies as usual. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_specificity`, :func:`multiclass_specificity` and :func:`multilabel_specificity` for the specific + details of each argument influence and examples. - ignore_index: - Integer specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and ``average=None`` - or ``'none'``, the score for the ignored class will be returned as ``nan``. - num_classes: - Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods. - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. If this parameter is set for multi-label - inputs, it will take precedence over ``threshold``. For (multi-dim) multi-class inputs, - this parameter defaults to 1. - - Should be left unset (``None``) for inputs with label predictions. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. 
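For the multilabel branch of the dispatch below, a minimal sketch with made-up data (keyword arguments follow the signature used in this patch):

    import torch
    from torchmetrics.functional.classification import multilabel_specificity

    preds = torch.tensor([[0.9, 0.2, 0.6], [0.1, 0.8, 0.3]])
    target = torch.tensor([[1, 0, 1], [0, 1, 0]])

    # probabilities are binarized at `threshold` (default 0.5) before
    # true negatives and false positives are counted per label
    multilabel_specificity(preds, target, num_labels=3, average="macro")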
- - Return: - The shape of the returned tensor depends on the ``average`` parameter - - - If ``average in ['micro', 'macro', 'weighted', 'samples']``, a one-element tensor will be returned - - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number - of classes - - Raises: - ValueError: - If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``, ``"samples"``, ``"none"`` or ``None`` - ValueError: - If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``average`` is set but ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - - Example: - >>> from torchmetrics.functional import specificity + LegacyExample: >>> preds = torch.tensor([2, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> specificity(preds, target, average='macro', num_classes=3) + >>> specificity(preds, target, task="multiclass", average='macro', num_classes=3) tensor(0.6111) - >>> specificity(preds, target, average='micro') + >>> specificity(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.6250) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_specificity( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_specificity( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_specificity( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_specificity( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - allowed_average = ("micro", "macro", "weighted", "samples", "none", None) - if average not in allowed_average: - raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - - allowed_mdmc_average = (None, "samplewise", "global") - if mdmc_average not in allowed_mdmc_average: - raise ValueError(f"The `mdmc_average` has to be one of {allowed_mdmc_average}, got {mdmc_average}.") - - if average in ["macro", "weighted", "none", None] and (not num_classes or num_classes < 1): - raise ValueError(f"When you set `average` as {average}, you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - reduce = "macro" if average in ["weighted", "none", None] else average - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_average, - threshold=threshold, - num_classes=num_classes, - top_k=top_k, - multiclass=multiclass, - ignore_index=ignore_index, + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - - return _specificity_compute(tp, fp, tn, fn, average, mdmc_average) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 22d8de769e4..6571aaccbac 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -20,7 +20,6 @@ from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod -from torchmetrics.utilities.prints import rank_zero_warn def _binary_stat_scores_arg_validation( @@ -913,7 +912,7 @@ def _stat_scores_update( reduce: Optional[str] = "micro", mdmc_reduce: Optional[str] = None, num_classes: Optional[int] = None, - top_k: Optional[int] = None, + top_k: Optional[int] = 1, threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, @@ -1002,18 +1001,6 @@ def _stat_scores_compute(tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> Tens fp: False positives tn: True negatives fn: False negatives - - Example: - >>> preds = torch.tensor([1, 0, 2, 1]) - >>> target = torch.tensor([1, 1, 2, 0]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='macro', num_classes=3) - >>> _stat_scores_compute(tp, fp, tn, fn) - tensor([[0, 1, 2, 1, 1], - [1, 1, 1, 1, 2], - [1, 0, 3, 0, 1]]) - >>> tp, fp, tn, fn = _stat_scores_update(preds, target, reduce='micro') - >>> _stat_scores_compute(tp, fp, tn, fn) - tensor([2, 2, 6, 2, 4]) """ stats = [ tp.unsqueeze(-1), @@ -1092,190 +1079,47 @@ def _reduce_stat_scores( def stat_scores( preds: Tensor, target: Tensor, - reduce: str = "micro", - mdmc_reduce: Optional[str] = None, - num_classes: Optional[int] = 
None, - top_k: Optional[int] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, - multiclass: Optional[bool] = None, - ignore_index: Optional[int] = None, - task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, num_labels: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - r"""Stat scores. - - .. note:: - From v0.10 an ``'binary_*'``, ``'multiclass_*'``, ``'multilabel_*'`` version now exist of each classification - metric. Moving forward we recommend using these versions. This base metric will still work as it did - prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required - and the general order of arguments may change, such that this metric will just function as an single - entrypoint to calling the three specialized versions. - - - Computes the number of true positives, false positives, true negatives, false negatives. - Related to `Type I and Type II errors`_ and the `confusion matrix`_. - - The reduction method (how the statistics are aggregated) is controlled by the - ``reduce`` parameter, and additionally by the ``mdmc_reduce`` parameter in the - multi-dimensional multi-class case. - - Args: - preds: Predictions from model (probabilities, logits or labels) - target: Ground truth values - threshold: - Threshold for transforming probability or logit predictions to binary (0,1) predictions, in the case - of binary or multi-label inputs. Default value of 0.5 corresponds to input being probabilities. - top_k: - Number of highest probability or logit score predictions considered to find the correct label, - relevant only for (multi-dimensional) multi-class inputs. The - default value (``None``) will be interpreted as 1 for these inputs. - - Should be left at default (``None``) for all other types of inputs. - reduce: - Defines the reduction that is applied. Should be one of the following: - - - ``'micro'`` [default]: Counts the statistics by summing over all [sample, class] - combinations (globally). Each statistic is represented by a single integer. - - ``'macro'``: Counts the statistics for each class separately (over all samples). - Each statistic is represented by a ``(C,)`` tensor. Requires ``num_classes`` - to be set. - - ``'samples'``: Counts the statistics for each sample separately (over all classes). - Each statistic is represented by a ``(N, )`` 1d tensor. - - .. note:: What is considered a sample in the multi-dimensional multi-class case - depends on the value of ``mdmc_reduce``. - - num_classes: - Number of classes. Necessary for (multi-dimensional) multi-class or multi-label data. - ignore_index: - Specify a class (label) to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. If an index is ignored, and - ``reduce='macro'``, the class statistics for the ignored class will all be returned - as ``-1``. - mdmc_reduce: - Defines how the multi-dimensional multi-class inputs are handeled. Should be - one of the following: - - - ``None`` [default]: Should be left unchanged if your data is not multi-dimensional - multi-class. 
- - - ``'samplewise'``: In this case, the statistics are computed separately for each - sample on the ``N`` axis, and then the outputs are concatenated together. In each - sample the extra axes ``...`` are flattened to become the sub-sample axis, and - statistics for each sample are computed by treating the sub-sample axis as the - ``N`` axis for that sample. + r"""Computes the number of true positives, false positives, true negatives, false negatives and the support. - - ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs are - flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they - were ``(N_X, C)``. From here on the ``reduce`` parameter applies as usual. + This function is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``multilabel``. See the documentation of + :func:`binary_stat_scores`, :func:`multiclass_stat_scores` and :func:`multilabel_stat_scores` for the specific + details of each argument influence and examples. - multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - Return: - The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds - to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The - shape depends on the ``reduce`` and ``mdmc_reduce`` (in case of multi-dimensional - multi-class data) parameters: - - - If the data is not multi-dimensional multi-class, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)``, where ``C`` stands for the number of classes - - If ``reduce='samples'``, the shape will be ``(N, 5)``, where ``N`` stands for - the number of samples - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='global'``, then - - - If ``reduce='micro'``, the shape will be ``(5, )`` - - If ``reduce='macro'``, the shape will be ``(C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N*X, 5)``, where ``X`` stands for - the product of sizes of all "extra" dimensions of the data (i.e. all dimensions - except for ``C`` and ``N``) - - - If the data is multi-dimensional multi-class and ``mdmc_reduce='samplewise'``, then - - - If ``reduce='micro'``, the shape will be ``(N, 5)`` - - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` - - Raises: - ValueError: - If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``. - ValueError: - If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``. - ValueError: - If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided. - ValueError: - If ``num_classes`` is set and ``ignore_index`` is not in the range ``[0, num_classes)``. - ValueError: - If ``ignore_index`` is used with ``binary data``. - ValueError: - If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided. 
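The ``[tp, fp, tn, fn, sup]`` layout described above carries over to the task-specific functions; a minimal sketch for the binary case with made-up data:

    import torch
    from torchmetrics.functional.classification import binary_stat_scores

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    # a single row of [tp, fp, tn, fn, sup], where sup = tp + fn
    binary_stat_scores(preds, target)  # tensor([1, 1, 2, 0, 1])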
- - Example: - >>> from torchmetrics.functional import stat_scores + Legacy Example: >>> preds = torch.tensor([1, 0, 2, 1]) >>> target = torch.tensor([1, 1, 2, 0]) - >>> stat_scores(preds, target, reduce='macro', num_classes=3) + >>> stat_scores(preds, target, task='multiclass', num_classes=3, average='micro') + tensor([2, 2, 6, 2, 4]) + >>> stat_scores(preds, target, task='multiclass', num_classes=3, average=None) tensor([[0, 1, 2, 1, 1], [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) - >>> stat_scores(preds, target, reduce='micro') - tensor([2, 2, 6, 2, 4]) """ - if task is not None: - assert multidim_average is not None - if task == "binary": - return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": - assert isinstance(num_classes, int) - assert isinstance(top_k, int) - return multiclass_stat_scores( - preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args - ) - if task == "multilabel": - assert isinstance(num_labels, int) - return multilabel_stat_scores( - preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args - ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + assert multidim_average is not None + if task == "binary": + return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) + if task == "multiclass": + assert isinstance(num_classes, int) + assert isinstance(top_k, int) + return multiclass_stat_scores( + preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - else: - rank_zero_warn( - "From v0.10 an `'binary_*'`, `'multiclass_*'`, `'multilabel_*'` version now exist of each classification" - " metric. Moving forward we recommend using these versions. This base metric will still work as it did" - " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" - " and the general order of arguments may change, such that this metric will just function as an single" - " entrypoint to calling the three specialized versions.", - DeprecationWarning, + if task == "multilabel": + assert isinstance(num_labels, int) + return multilabel_stat_scores( + preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - if reduce not in ["micro", "macro", "samples"]: - raise ValueError(f"The `reduce` {reduce} is not valid.") - - if mdmc_reduce not in [None, "samplewise", "global"]: - raise ValueError(f"The `mdmc_reduce` {mdmc_reduce} is not valid.") - - if reduce == "macro" and (not num_classes or num_classes < 1): - raise ValueError("When you set `reduce` as 'macro', you have to provide the number of classes.") - - if num_classes and ignore_index is not None and (not 0 <= ignore_index < num_classes or num_classes == 1): - raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes") - - tp, fp, tn, fn = _stat_scores_update( - preds, - target, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - top_k=top_k, - threshold=threshold, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" ) - return _stat_scores_compute(tp, fp, tn, fn) diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 4936aac7a0b..f3d5d6e8caf 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -646,9 +646,9 @@ def check_forward_full_state_property( reps: number of repetitions of speedup test Example (states in ``update`` are independent, save to set ``full_state_update=False``) - >>> from torchmetrics import ConfusionMatrix + >>> from torchmetrics.classification import MulticlassConfusionMatrix >>> check_forward_full_state_property( # doctest: +ELLIPSIS - ... ConfusionMatrix, + ... MulticlassConfusionMatrix, ... init_args = {'num_classes': 3}, ... input_args = {'preds': torch.randint(3, (100,)), 'target': torch.randint(3, (100,))}, ... ) @@ -661,8 +661,8 @@ def check_forward_full_state_property( Recommended setting `full_state_update=False` Example (states in ``update`` are dependend meaning that ``full_state_update=True``): - >>> from torchmetrics import ConfusionMatrix - >>> class MyMetric(ConfusionMatrix): + >>> from torchmetrics.classification import MulticlassConfusionMatrix + >>> class MyMetric(MulticlassConfusionMatrix): ... def update(self, preds, target): ... super().update(preds, target) ... 
# by construction make future states dependent on prior states diff --git a/src/torchmetrics/wrappers/bootstrapping.py b/src/torchmetrics/wrappers/bootstrapping.py index 0cf1485cf4a..cd06b26c6bf 100644 --- a/src/torchmetrics/wrappers/bootstrapping.py +++ b/src/torchmetrics/wrappers/bootstrapping.py @@ -69,15 +69,17 @@ class basically keeps multiple copies of the same base metric in memory and when Example:: >>> from pprint import pprint - >>> from torchmetrics import Accuracy, BootStrapper + >>> from torchmetrics import BootStrapper + >>> from torchmetrics.classification import MulticlassAccuracy >>> _ = torch.manual_seed(123) - >>> base_metric = Accuracy() + >>> base_metric = MulticlassAccuracy(num_classes=5, average='micro') >>> bootstrap = BootStrapper(base_metric, num_bootstraps=20) >>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,))) >>> output = bootstrap.compute() >>> pprint(output) {'mean': tensor(0.2205), 'std': tensor(0.0859)} """ + full_state_update: Optional[bool] = True def __init__( self, diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 82fc867e80c..e4fd999a323 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -30,38 +30,49 @@ class ClasswiseWrapper(Metric): Example: >>> import torch >>> _ = torch.manual_seed(42) - >>> from torchmetrics import Accuracy, ClasswiseWrapper - >>> metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) + >>> from torchmetrics import ClasswiseWrapper + >>> from torchmetrics.classification import MulticlassAccuracy + >>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)) >>> preds = torch.randn(10, 3).softmax(dim=-1) >>> target = torch.randint(3, (10,)) - >>> metric(preds, target) - {'accuracy_0': tensor(0.5000), 'accuracy_1': tensor(0.7500), 'accuracy_2': tensor(0.)} + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'multiclassaccuracy_0': tensor(0.5000), + 'multiclassaccuracy_1': tensor(0.7500), + 'multiclassaccuracy_2': tensor(0.)} Example (labels as list of strings): >>> import torch - >>> from torchmetrics import Accuracy, ClasswiseWrapper + >>> from torchmetrics import ClasswiseWrapper + >>> from torchmetrics.classification import MulticlassAccuracy >>> metric = ClasswiseWrapper( - ... Accuracy(num_classes=3, average=None), + ... MulticlassAccuracy(num_classes=3, average=None), ... labels=["horse", "fish", "dog"] ... ) >>> preds = torch.randn(10, 3).softmax(dim=-1) >>> target = torch.randint(3, (10,)) - >>> metric(preds, target) - {'accuracy_horse': tensor(0.3333), 'accuracy_fish': tensor(0.6667), 'accuracy_dog': tensor(0.)} + >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE + {'multiclassaccuracy_horse': tensor(0.3333), + 'multiclassaccuracy_fish': tensor(0.6667), + 'multiclassaccuracy_dog': tensor(0.)} Example (in metric collection): >>> import torch - >>> from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall + >>> from torchmetrics import ClasswiseWrapper, MetricCollection + >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall >>> labels = ["horse", "fish", "dog"] >>> metric = MetricCollection( - ... {'accuracy': ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels), - ... 'recall': ClasswiseWrapper(Recall(num_classes=3, average=None), labels)} + ... {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels), + ... 
'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)} ... ) >>> preds = torch.randn(10, 3).softmax(dim=-1) >>> target = torch.randint(3, (10,)) >>> metric(preds, target) # doctest: +NORMALIZE_WHITESPACE - {'accuracy_horse': tensor(0.), 'accuracy_fish': tensor(0.3333), 'accuracy_dog': tensor(0.4000), - 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)} + {'multiclassaccuracy_horse': tensor(0.), + 'multiclassaccuracy_fish': tensor(0.3333), + 'multiclassaccuracy_dog': tensor(0.4000), + 'multiclassrecall_horse': tensor(0.), + 'multiclassrecall_fish': tensor(0.3333), + 'multiclassrecall_dog': tensor(0.4000)} """ def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: diff --git a/src/torchmetrics/wrappers/minmax.py b/src/torchmetrics/wrappers/minmax.py index 713131f083f..fb99c940d57 100644 --- a/src/torchmetrics/wrappers/minmax.py +++ b/src/torchmetrics/wrappers/minmax.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import torch from torch import Tensor @@ -35,9 +35,10 @@ class MinMaxMetric(Metric): Example:: >>> import torch - >>> from torchmetrics import Accuracy + >>> from torchmetrics import MinMaxMetric + >>> from torchmetrics.classification import BinaryAccuracy >>> from pprint import pprint - >>> base_metric = Accuracy() + >>> base_metric = BinaryAccuracy() >>> minmax_metric = MinMaxMetric(base_metric) >>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]]) >>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]]) @@ -51,6 +52,7 @@ class MinMaxMetric(Metric): {'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)} """ + full_state_update: Optional[bool] = True min_val: Tensor max_val: Tensor diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 25aaa19441a..59288e4fc4e 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -41,9 +41,10 @@ class MetricTracker(ModuleList): better (``True``) or lower is better (``False``). Example (single metric): - >>> from torchmetrics import Accuracy, MetricTracker + >>> from torchmetrics import MetricTracker + >>> from torchmetrics.classification import MulticlassAccuracy >>> _ = torch.manual_seed(42) - >>> tracker = MetricTracker(Accuracy(num_classes=10)) + >>> tracker = MetricTracker(MulticlassAccuracy(num_classes=10, average='micro')) >>> for epoch in range(5): ... tracker.increment() ... 
for batch_idx in range(5): diff --git a/tests/integrations/test_lightning.py b/tests/integrations/test_lightning.py index 68dc962cc4c..e05d71b3cb7 100644 --- a/tests/integrations/test_lightning.py +++ b/tests/integrations/test_lightning.py @@ -20,7 +20,8 @@ from integrations.helpers import no_warning_call from integrations.lightning.boring_model import BoringModel, RandomDataset -from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric +from torchmetrics import MetricCollection, SumMetric +from torchmetrics.classification import BinaryAccuracy, BinaryAveragePrecision class DiffMetric(SumMetric): @@ -73,9 +74,9 @@ def __init__(self): self.layer = torch.nn.Linear(32, 1) for stage in ["train", "val", "test"]: - acc = Accuracy() + acc = BinaryAccuracy() acc.reset = mock.Mock(side_effect=acc.reset) - ap = AveragePrecision(num_classes=1, pos_label=1) + ap = BinaryAveragePrecision() ap.reset = mock.Mock(side_effect=ap.reset) self.add_module(f"acc_{stage}", acc) self.add_module(f"ap_{stage}", ap) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 34b9d1af1c1..28e7b2e1de0 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -18,18 +18,20 @@ import pytest import torch -from torchmetrics import ( - AUROC, - Accuracy, - AveragePrecision, - CohenKappa, - ConfusionMatrix, - F1Score, - MatthewsCorrCoef, - Metric, - MetricCollection, - Precision, - Recall, +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import ( + BinaryAccuracy, + MulticlassAccuracy, + MulticlassAUROC, + MulticlassAveragePrecision, + MulticlassCohenKappa, + MulticlassConfusionMatrix, + MulticlassF1Score, + MulticlassMatthewsCorrCoef, + MulticlassPrecision, + MulticlassRecall, + MultilabelAUROC, + MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive from unittests.helpers import seed_all @@ -293,15 +295,15 @@ def update(self, preds, target, kwarg2): def compute(self): return - mc = MetricCollection([Accuracy(), DummyMetric()]) + mc = MetricCollection([BinaryAccuracy(), DummyMetric()]) mc2 = MetricCollection([MyAccuracy(), DummyMetric()]) mc(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg") mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2") # function for generating -_mc_preds = torch.randn(10, 3).softmax(dim=-1) -_mc_target = torch.randint(3, (10,)) +_mc_preds = torch.randn(10, 3, 2).softmax(dim=1) +_mc_target = torch.randint(3, (10, 2)) _ml_preds = torch.rand(10, 3) _ml_target = torch.randint(2, (10, 3)) @@ -310,29 +312,49 @@ def compute(self): "metrics, expected, preds, target", [ # single metric forms its own compute group - (Accuracy(3), {0: ["Accuracy"]}, _mc_preds, _mc_target), + (MulticlassAccuracy(num_classes=3), {0: ["MulticlassAccuracy"]}, _mc_preds, _mc_target), # two metrics of same class forms a compute group - ({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}, _mc_preds, _mc_target), + ( + {"acc0": MulticlassAccuracy(num_classes=3), "acc1": MulticlassAccuracy(num_classes=3)}, + {0: ["acc0", "acc1"]}, + _mc_preds, + _mc_target, + ), # two metrics from registry forms a compute group - ([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}, _mc_preds, _mc_target), + ( + [MulticlassPrecision(num_classes=3), MulticlassRecall(num_classes=3)], + {0: ["MulticlassPrecision", "MulticlassRecall"]}, + _mc_preds, + _mc_target, + ), # two metrics from
different classes gives two compute groups - ([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}, _mc_preds, _mc_target), + ( + [MulticlassConfusionMatrix(num_classes=3), MulticlassRecall(num_classes=3)], + {0: ["MulticlassConfusionMatrix"], 1: ["MulticlassRecall"]}, + _mc_preds, + _mc_target, + ), # multi group multi metric ( - [ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)], - {0: ["ConfusionMatrix", "CohenKappa"], 1: ["Recall", "Precision"]}, + [ + MulticlassConfusionMatrix(num_classes=3), + MulticlassCohenKappa(num_classes=3), + MulticlassRecall(num_classes=3), + MulticlassPrecision(num_classes=3), + ], + {0: ["MulticlassConfusionMatrix", "MulticlassCohenKappa"], 1: ["MulticlassRecall", "MulticlassPrecision"]}, _mc_preds, _mc_target, ), # Complex example ( { - "acc": Accuracy(3), - "acc2": Accuracy(3), - "acc3": Accuracy(num_classes=3, average="macro"), - "f1": F1Score(3), - "recall": Recall(3), - "confmat": ConfusionMatrix(3), + "acc": MulticlassAccuracy(num_classes=3), + "acc2": MulticlassAccuracy(num_classes=3), + "acc3": MulticlassAccuracy(num_classes=3, multidim_average="samplewise"), + "f1": MulticlassF1Score(num_classes=3), + "recall": MulticlassRecall(num_classes=3), + "confmat": MulticlassConfusionMatrix(num_classes=3), }, {0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]}, _mc_preds, @@ -340,8 +362,11 @@ def compute(self): ), # With list states ( - [AUROC(average="macro", num_classes=3), AveragePrecision(average="macro", num_classes=3)], - {0: ["AUROC", "AveragePrecision"]}, + [ + MulticlassAUROC(num_classes=3, average="macro"), + MulticlassAveragePrecision(num_classes=3, average="macro"), + ], + {0: ["MulticlassAUROC", "MulticlassAveragePrecision"]}, _mc_preds, _mc_target, ), @@ -349,17 +374,24 @@ def compute(self): ( [ MetricCollection( - AUROC(average="micro", num_classes=3), - AveragePrecision(average="micro", num_classes=3), + MultilabelAUROC(num_labels=3, average="micro"), + MultilabelAveragePrecision(num_labels=3, average="micro"), postfix="_micro", ), MetricCollection( - AUROC(average="macro", num_classes=3), - AveragePrecision(average="macro", num_classes=3), + MultilabelAUROC(num_labels=3, average="macro"), + MultilabelAveragePrecision(num_labels=3, average="macro"), postfix="_macro", ), ], - {0: ["AUROC_micro", "AveragePrecision_micro", "AUROC_macro", "AveragePrecision_macro"]}, + { + 0: [ + "MultilabelAUROC_micro", + "MultilabelAveragePrecision_micro", + "MultilabelAUROC_macro", + "MultilabelAveragePrecision_macro", + ] + }, _ml_preds, _ml_target, ), @@ -447,20 +479,20 @@ def _compare(m1, m2): @pytest.mark.parametrize( "metrics", [ - {"acc0": Accuracy(3), "acc1": Accuracy(3)}, - [Precision(3), Recall(3)], - [ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)], + {"acc0": MulticlassAccuracy(3), "acc1": MulticlassAccuracy(3)}, + [MulticlassPrecision(3), MulticlassRecall(3)], + [MulticlassConfusionMatrix(3), MulticlassCohenKappa(3), MulticlassRecall(3), MulticlassPrecision(3)], { - "acc": Accuracy(3), - "acc2": Accuracy(3), - "acc3": Accuracy(num_classes=3, average="macro"), - "f1": F1Score(3), - "recall": Recall(3), - "confmat": ConfusionMatrix(3), + "acc": MulticlassAccuracy(3), + "acc2": MulticlassAccuracy(3), + "acc3": MulticlassAccuracy(num_classes=3, average="macro"), + "f1": MulticlassF1Score(3), + "recall": MulticlassRecall(3), + "confmat": MulticlassConfusionMatrix(3), }, ], ) -@pytest.mark.parametrize("steps", [100, 1000]) +@pytest.mark.parametrize("steps", [1000]) def 
test_check_compute_groups_is_faster(metrics, steps): """Check that compute groups are formed after initialization.""" m = MetricCollection(deepcopy(metrics), compute_groups=True) @@ -486,12 +518,15 @@ def test_check_compute_groups_is_faster(metrics, steps): def test_compute_group_define_by_user(): """Check that user can provide compute groups.""" m = MetricCollection( - ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Precision"]] + MulticlassConfusionMatrix(3), + MulticlassRecall(3), + MulticlassPrecision(3), + compute_groups=[["MulticlassConfusionMatrix"], ["MulticlassRecall", "MulticlassPrecision"]], ) # Check that we are not going to check the groups in the first update assert m._groups_checked - assert m.compute_groups == {0: ["ConfusionMatrix"], 1: ["Recall", "Precision"]} + assert m.compute_groups == {0: ["MulticlassConfusionMatrix"], 1: ["MulticlassRecall", "MulticlassPrecision"]} preds = torch.randn(10, 3).softmax(dim=-1) target = torch.randint(3, (10,)) @@ -503,25 +538,28 @@ def test_compute_on_different_dtype(): """Check that extraction of compute groups are robust towards difference in dtype.""" m = MetricCollection( [ - ConfusionMatrix(num_classes=3), - MatthewsCorrCoef(num_classes=3), + MulticlassConfusionMatrix(num_classes=3), + MulticlassMatthewsCorrCoef(num_classes=3), ] ) assert not m._groups_checked - assert m.compute_groups == {0: ["ConfusionMatrix"], 1: ["MatthewsCorrCoef"]} + assert m.compute_groups == {0: ["MulticlassConfusionMatrix"], 1: ["MulticlassMatthewsCorrCoef"]} preds = torch.randn(10, 3).softmax(dim=-1) target = torch.randint(3, (10,)) for _ in range(2): m.update(preds, target) - assert m.compute_groups == {0: ["ConfusionMatrix", "MatthewsCorrCoef"]} + assert m.compute_groups == {0: ["MulticlassConfusionMatrix", "MulticlassMatthewsCorrCoef"]} assert m.compute() def test_error_on_wrong_specified_compute_groups(): """Test that error is raised if user mis-specify the compute groups.""" - with pytest.raises(ValueError, match="Input Accuracy in `compute_groups`.*"): + with pytest.raises(ValueError, match="Input MulticlassAccuracy in `compute_groups`.*"): MetricCollection( - ConfusionMatrix(3), Recall(3), Precision(3), compute_groups=[["ConfusionMatrix"], ["Recall", "Accuracy"]] + MulticlassConfusionMatrix(3), + MulticlassRecall(3), + MulticlassPrecision(3), + compute_groups=[["MulticlassConfusionMatrix"], ["MulticlassRecall", "MulticlassAccuracy"]], ) @@ -530,18 +568,32 @@ def test_error_on_wrong_specified_compute_groups(): [ [ MetricCollection( - [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")], prefix="macro_" + [ + MulticlassAccuracy(num_classes=3, average="macro"), + MulticlassPrecision(num_classes=3, average="macro"), + ], + prefix="macro_", ), MetricCollection( - [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")], prefix="micro_" + [ + MulticlassAccuracy(num_classes=3, average="micro"), + MulticlassPrecision(num_classes=3, average="micro"), + ], + prefix="micro_", ), ], { "macro": MetricCollection( - [Accuracy(num_classes=3, average="macro"), Precision(num_classes=3, average="macro")] + [ + MulticlassAccuracy(num_classes=3, average="macro"), + MulticlassPrecision(num_classes=3, average="macro"), + ] ), "micro": MetricCollection( - [Accuracy(num_classes=3, average="micro"), Precision(num_classes=3, average="micro")] + [ + MulticlassAccuracy(num_classes=3, average="micro"), + MulticlassPrecision(num_classes=3, average="micro"), + ] ), }, ], @@ 
-552,7 +604,7 @@ def test_nested_collections(input_collections): preds = torch.randn(10, 3).softmax(dim=-1) target = torch.randint(3, (10,)) val = metrics(preds, target) - assert "valmetrics/macro_Accuracy" in val - assert "valmetrics/macro_Precision" in val - assert "valmetrics/micro_Accuracy" in val - assert "valmetrics/micro_Precision" in val + assert "valmetrics/macro_MulticlassAccuracy" in val + assert "valmetrics/macro_MulticlassPrecision" in val + assert "valmetrics/micro_MulticlassAccuracy" in val + assert "valmetrics/micro_MulticlassPrecision" in val diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index ae2ae4d75cf..46ea1ca988b 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -24,7 +24,8 @@ from torch import Tensor, tensor from torch.nn import Module -from torchmetrics import Accuracy, PearsonCorrCoef +from torchmetrics import PearsonCorrCoef +from torchmetrics.classification import BinaryAccuracy from unittests.helpers import seed_all from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum from unittests.helpers.utilities import no_warning_call @@ -443,7 +444,7 @@ def forward(self, *args, **kwargs): def test_custom_availability_check_and_sync_fn(): dummy_availability_check = Mock(return_value=True) dummy_dist_sync_fn = Mock(wraps=lambda x, group: [x]) - acc = Accuracy(dist_sync_fn=dummy_dist_sync_fn, distributed_available_fn=dummy_availability_check) + acc = BinaryAccuracy(dist_sync_fn=dummy_dist_sync_fn, distributed_available_fn=dummy_availability_check) acc.update(torch.tensor([[1], [1], [1], [1]]), torch.tensor([[1], [1], [1], [1]])) dummy_dist_sync_fn.assert_not_called() diff --git a/tests/unittests/wrappers/test_bootstrapping.py b/tests/unittests/wrappers/test_bootstrapping.py index 6b5ea1ba6a1..fa57afe6885 100644 --- a/tests/unittests/wrappers/test_bootstrapping.py +++ b/tests/unittests/wrappers/test_bootstrapping.py @@ -20,7 +20,8 @@ from sklearn.metrics import mean_squared_error, precision_score, recall_score from torch import Tensor -from torchmetrics import MeanSquaredError, Precision, Recall +from torchmetrics import MeanSquaredError +from torchmetrics.classification import MulticlassPrecision, MulticlassRecall from torchmetrics.utilities import apply_to_collection from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler from unittests.helpers import seed_all @@ -78,8 +79,8 @@ def test_bootstrap_sampler(sampling_strategy): @pytest.mark.parametrize( "metric, sk_metric", [ - [Precision(average="micro"), partial(precision_score, average="micro")], - [Recall(average="micro"), partial(recall_score, average="micro")], + [MulticlassPrecision(num_classes=10, average="micro"), partial(precision_score, average="micro")], + [MulticlassRecall(num_classes=10, average="micro"), partial(recall_score, average="micro")], [MeanSquaredError(), mean_squared_error], ], ) diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index e5c8fbf2415..e2e0565c2a9 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -1,7 +1,8 @@ import pytest import torch -from torchmetrics import Accuracy, ClasswiseWrapper, MetricCollection, Recall +from torchmetrics import ClasswiseWrapper, MetricCollection +from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall def test_raises_error_on_wrong_input(): @@ -10,13 +11,13 @@ def 
test_raises_error_on_wrong_input(): ClasswiseWrapper([]) with pytest.raises(ValueError, match="Expected argument `labels` to either be `None` or a list of strings.*"): - ClasswiseWrapper(Accuracy(), "hest") + ClasswiseWrapper(MulticlassAccuracy(num_classes=3), "hest") def test_output_no_labels(): """Test that wrapper works with no label input.""" - base = Accuracy(num_classes=3, average=None) - metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) + base = MulticlassAccuracy(num_classes=3, average=None) + metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None)) for _ in range(2): preds = torch.randn(20, 3).softmax(dim=-1) target = torch.randint(3, (20,)) @@ -25,15 +26,15 @@ def test_output_no_labels(): assert isinstance(val, dict) assert len(val) == 3 for i in range(3): - assert f"accuracy_{i}" in val - assert val[f"accuracy_{i}"] == val_base[i] + assert f"multiclassaccuracy_{i}" in val + assert val[f"multiclassaccuracy_{i}"] == val_base[i] def test_output_with_labels(): """Test that wrapper works with label input.""" labels = ["horse", "fish", "cat"] - base = Accuracy(num_classes=3, average=None) - metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels) + base = MulticlassAccuracy(num_classes=3, average=None) + metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels=labels) for _ in range(2): preds = torch.randn(20, 3).softmax(dim=-1) target = torch.randint(3, (20,)) @@ -42,15 +43,15 @@ def test_output_with_labels(): assert isinstance(val, dict) assert len(val) == 3 for i, lab in enumerate(labels): - assert f"accuracy_{lab}" in val - assert val[f"accuracy_{lab}"] == val_base[i] + assert f"multiclassaccuracy_{lab}" in val + assert val[f"multiclassaccuracy_{lab}"] == val_base[i] val = metric.compute() val_base = base.compute() assert isinstance(val, dict) assert len(val) == 3 for i, lab in enumerate(labels): - assert f"accuracy_{lab}" in val - assert val[f"accuracy_{lab}"] == val_base[i] + assert f"multiclassaccuracy_{lab}" in val + assert val[f"multiclassaccuracy_{lab}"] == val_base[i] @pytest.mark.parametrize("prefix", [None, "pre_"]) @@ -60,8 +61,8 @@ def test_using_metriccollection(prefix, postfix): labels = ["horse", "fish", "cat"] metric = MetricCollection( { - "accuracy": ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels), - "recall": ClasswiseWrapper(Recall(num_classes=3, average=None), labels=labels), + "accuracy": ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels=labels), + "recall": ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels=labels), }, prefix=prefix, postfix=postfix, @@ -78,7 +79,7 @@ def _get_correct_name(base): return name for lab in labels: - name = _get_correct_name(f"accuracy_{lab}") + name = _get_correct_name(f"multiclassaccuracy_{lab}") assert name in val - name = _get_correct_name(f"recall_{lab}") + name = _get_correct_name(f"multiclassrecall_{lab}") assert name in val diff --git a/tests/unittests/wrappers/test_minmax.py b/tests/unittests/wrappers/test_minmax.py index 90e54bd2eeb..8be6b288194 100644 --- a/tests/unittests/wrappers/test_minmax.py +++ b/tests/unittests/wrappers/test_minmax.py @@ -5,7 +5,8 @@ import torch from torch import Tensor -from torchmetrics import Accuracy, ConfusionMatrix, MeanSquaredError +from torchmetrics import MeanSquaredError +from torchmetrics.classification import BinaryAccuracy, BinaryConfusionMatrix, MulticlassAccuracy from torchmetrics.wrappers import MinMaxMetric from unittests.helpers 
import seed_all from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester @@ -56,7 +57,7 @@ def compare_fn_ddp(preds, target, base_fn): ( torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(dim=-1), torch.randint(NUM_CLASSES, (NUM_BATCHES, BATCH_SIZE)), - Accuracy(num_classes=NUM_CLASSES), + MulticlassAccuracy(num_classes=NUM_CLASSES), ), (torch.randn(NUM_BATCHES, BATCH_SIZE), torch.randn(NUM_BATCHES, BATCH_SIZE), MeanSquaredError()), ], @@ -96,7 +97,7 @@ def test_minmax_wrapper(self, preds, target, base_metric, ddp): ) def test_basic_example(preds, labels, raws, maxs, mins) -> None: """tests that both min and max versions of MinMaxMetric operate correctly after calling compute.""" - acc = Accuracy() + acc = BinaryAccuracy() min_max_acc = MinMaxMetric(acc) labels = Tensor(labels).long() @@ -117,7 +118,7 @@ def test_no_base_metric() -> None: def test_no_scalar_compute() -> None: """tests that an assertion error is thrown if the wrapped basemetric gives a non-scalar on compute.""" - min_max_nsm = MinMaxMetric(ConfusionMatrix(num_classes=2)) + min_max_nsm = MinMaxMetric(BinaryConfusionMatrix()) with pytest.raises(RuntimeError, match=r"Returned value from base metric should be a scalar .*"): min_max_nsm.compute() diff --git a/tests/unittests/wrappers/test_multioutput.py b/tests/unittests/wrappers/test_multioutput.py index 8ed1bb139e7..00398bdb92d 100644 --- a/tests/unittests/wrappers/test_multioutput.py +++ b/tests/unittests/wrappers/test_multioutput.py @@ -8,7 +8,7 @@ from torch import Tensor from torchmetrics import Metric -from torchmetrics.classification import Accuracy +from torchmetrics.classification import MulticlassAccuracy from torchmetrics.regression import R2Score from torchmetrics.wrappers.multioutput import MultioutputWrapper from unittests.helpers import seed_all @@ -94,12 +94,12 @@ def _multi_target_sk_accuracy(preds, target, num_outputs): {}, ), ( - Accuracy, + MulticlassAccuracy, partial(_multi_target_sk_accuracy, num_outputs=2), _multi_target_classification_inputs.preds, _multi_target_classification_inputs.target, num_targets, - dict(num_classes=NUM_CLASSES), + dict(num_classes=NUM_CLASSES, average="micro"), ), ], ) diff --git a/tests/unittests/wrappers/test_tracker.py b/tests/unittests/wrappers/test_tracker.py index a112cfb9bae..a2981d2d126 100644 --- a/tests/unittests/wrappers/test_tracker.py +++ b/tests/unittests/wrappers/test_tracker.py @@ -15,14 +15,12 @@ import pytest import torch -from torchmetrics import ( - Accuracy, - ConfusionMatrix, - MeanAbsoluteError, - MeanSquaredError, - MetricCollection, - Precision, - Recall, +from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchmetrics.classification import ( + MulticlassAccuracy, + MulticlassConfusionMatrix, + MulticlassPrecision, + MulticlassRecall, ) from torchmetrics.wrappers import MetricTracker from unittests.helpers import seed_all @@ -53,7 +51,7 @@ def test_raises_error_on_wrong_input(): ], ) def test_raises_error_if_increment_not_called(method, method_input): - tracker = MetricTracker(Accuracy(num_classes=10)) + tracker = MetricTracker(MulticlassAccuracy(num_classes=10)) with pytest.raises(ValueError, match=f"`{method}` cannot be called before .*"): if method_input is not None: getattr(tracker, method)(*method_input) @@ -64,18 +62,30 @@ def test_raises_error_if_increment_not_called(method, method_input): @pytest.mark.parametrize( "base_metric, metric_input, maximize", [ - (Accuracy(num_classes=10), (torch.randint(10, (50,)),
torch.randint(10, (50,))), True), - (Precision(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), - (Recall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (MulticlassAccuracy(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (MulticlassPrecision(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), + (MulticlassRecall(num_classes=10), (torch.randint(10, (50,)), torch.randint(10, (50,))), True), (MeanSquaredError(), (torch.randn(50), torch.randn(50)), False), (MeanAbsoluteError(), (torch.randn(50), torch.randn(50)), False), ( - MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), + MetricCollection( + [ + MulticlassAccuracy(num_classes=10), + MulticlassPrecision(num_classes=10), + MulticlassRecall(num_classes=10), + ] + ), (torch.randint(10, (50,)), torch.randint(10, (50,))), True, ), ( - MetricCollection([Accuracy(num_classes=10), Precision(num_classes=10), Recall(num_classes=10)]), + MetricCollection( + [ + MulticlassAccuracy(num_classes=10), + MulticlassPrecision(num_classes=10), + MulticlassRecall(num_classes=10), + ] + ), (torch.randint(10, (50,)), torch.randint(10, (50,))), [True, True, True], ), @@ -133,8 +143,8 @@ def test_tracker(base_metric, metric_input, maximize): @pytest.mark.parametrize( "base_metric", [ - ConfusionMatrix(3), - MetricCollection([ConfusionMatrix(3), Accuracy(3)]), + MulticlassConfusionMatrix(3), + MetricCollection([MulticlassConfusionMatrix(3), MulticlassAccuracy(3)]), ], ) def test_best_metric_for_not_well_defined_metric_collection(base_metric): @@ -149,8 +159,8 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric): with pytest.warns(UserWarning, match="Encountered the following error when trying to get the best metric.*"): best = tracker.best_metric() if isinstance(best, dict): - assert best["Accuracy"] is not None - assert best["ConfusionMatrix"] is None + assert best["MulticlassAccuracy"] is not None + assert best["MulticlassConfusionMatrix"] is None else: assert best is None @@ -158,10 +168,10 @@ def test_best_metric_for_not_well_defined_metric_collection(base_metric): idx, best = tracker.best_metric(return_step=True) if isinstance(best, dict): - assert best["Accuracy"] is not None - assert best["ConfusionMatrix"] is None - assert idx["Accuracy"] is not None - assert idx["ConfusionMatrix"] is None + assert best["MulticlassAccuracy"] is not None + assert best["MulticlassConfusionMatrix"] is None + assert idx["MulticlassAccuracy"] is not None + assert idx["MulticlassConfusionMatrix"] is None else: assert best is None assert idx is None
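
The changes above all follow one migration pattern: the legacy single-entrypoint classes (Accuracy, Precision, AveragePrecision, ConfusionMatrix, ...) are replaced by task-specific classes from torchmetrics.classification, and the legacy functional entrypoints dispatch on an explicit `task` argument to the specialized functions, as in the stat_scores hunk above. Below is a minimal sketch of the two call styles, assuming torchmetrics >= 0.11; the tensors and the choices num_classes=3 / average="micro" are illustrative and not taken from this patch.

    import torch
    from torchmetrics.classification import MulticlassAccuracy
    from torchmetrics.functional import stat_scores
    from torchmetrics.functional.classification import multiclass_stat_scores

    preds = torch.randint(3, (10,))   # illustrative hard predictions
    target = torch.randint(3, (10,))  # illustrative targets

    # Class-based API: the task is encoded in the class name.
    acc = MulticlassAccuracy(num_classes=3, average="micro")
    acc.update(preds, target)
    print(acc.compute())

    # Functional API: the generic entrypoint dispatches on `task` ...
    print(stat_scores(preds, target, task="multiclass", num_classes=3, average="micro"))
    # ... which, per the dispatch shown above, forwards to the specialized function.
    print(multiclass_stat_scores(preds, target, num_classes=3, average="micro"))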