diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c6efbc16ce..7b8e75e2e36 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442)) +- Added `NegativePredictiveValue` to classification metrics ([#2433](https://github.com/Lightning-AI/torchmetrics/pull/2433)) + + - Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786)) diff --git a/docs/source/classification/negative_predictive_value.rst b/docs/source/classification/negative_predictive_value.rst new file mode 100644 index 00000000000..c35ef32c2fa --- /dev/null +++ b/docs/source/classification/negative_predictive_value.rst @@ -0,0 +1,56 @@ +.. customcarditem:: + :header: Negative Predictive Value + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +######################### +Negative Predictive Value +######################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.NegativePredictiveValue + :exclude-members: update, compute + :special-members: __new__ + +BinaryNegativePredictiveValue +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryNegativePredictiveValue + :exclude-members: update, compute + +MulticlassNegativePredictiveValue +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassNegativePredictiveValue + :exclude-members: update, compute + +MultilabelNegativePredictiveValue +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelNegativePredictiveValue + :exclude-members: update, compute + + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.negative_predictive_value + +binary_negative_predictive_value +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_negative_predictive_value + +multiclass_negative_predictive_value +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_negative_predictive_value + +multilabel_negative_predictive_value +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_negative_predictive_value diff --git a/docs/source/links.rst b/docs/source/links.rst index ed01989b3cc..4f2cbe6ad53 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -177,3 +177,4 @@ .. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis +.. 
_Negative Predictive Value: https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index a6105df3480..fee1596eb1d 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -62,6 +62,7 @@ HingeLoss, JaccardIndex, MatthewsCorrCoef, + NegativePredictiveValue, Precision, PrecisionAtFixedRecall, PrecisionRecallCurve, @@ -207,6 +208,7 @@ "MultiScaleStructuralSimilarityIndexMeasure", "MultioutputWrapper", "MultitaskWrapper", + "NegativePredictiveValue", "NormalizedRootMeanSquaredError", "PanopticQuality", "PeakSignalNoiseRatio", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 988a01c2947..86f334c970d 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -63,6 +63,12 @@ MulticlassMatthewsCorrCoef, MultilabelMatthewsCorrCoef, ) +from torchmetrics.classification.negative_predictive_value import ( + BinaryNegativePredictiveValue, + MulticlassNegativePredictiveValue, + MultilabelNegativePredictiveValue, + NegativePredictiveValue, +) from torchmetrics.classification.precision_fixed_recall import ( BinaryPrecisionAtFixedRecall, MulticlassPrecisionAtFixedRecall, @@ -217,4 +223,8 @@ "MulticlassSensitivityAtSpecificity", "MultilabelSensitivityAtSpecificity", "SensitivityAtSpecificity", + "BinaryNegativePredictiveValue", + "MulticlassNegativePredictiveValue", + "MultilabelNegativePredictiveValue", + "NegativePredictiveValue", ] diff --git a/src/torchmetrics/classification/negative_predictive_value.py b/src/torchmetrics/classification/negative_predictive_value.py new file mode 100644 index 00000000000..d0b19dc1247 --- /dev/null +++ b/src/torchmetrics/classification/negative_predictive_value.py @@ -0,0 +1,521 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Type, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.negative_predictive_value import _negative_predictive_value_reduce +from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = [ + "BinaryNegativePredictiveValue.plot", + "MulticlassNegativePredictiveValue.plot", + "MultilabelNegativePredictiveValue.plot", + ] + + +class BinaryNegativePredictiveValue(BinaryStatScores): + r"""Compute `Negative Predictive Value`_ for binary tasks. + + .. 
math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FN}}
+
+    Where :math:`\text{TN}` and :math:`\text{FN}` represent the number of true negatives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TN} + \text{FN} \neq 0`. If this case is
+    encountered, a score of 0 is returned.
+
+    As input to ``forward`` and ``update`` the metric accepts the following input:
+
+    - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating point
+      tensor with values outside the [0,1] range, we consider the input to be logits and will auto apply sigmoid
+      per element. Additionally, we convert to an int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``
+
+    As output to ``forward`` and ``compute`` the metric returns the following output:
+
+    - ``npv`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar
+      value. If ``multidim_average`` is set to ``samplewise``, the metric returns a ``(N,)`` vector consisting of a
+      scalar value per sample.
+
+    If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be
+    present, which the reduction will then be applied over instead of the sample dimension ``N``.
+
+    Args:
+        threshold: Threshold for transforming probability to binary {0,1} predictions
+        multidim_average:
+            Defines how additional dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flattened along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Example (preds is int tensor):
+        >>> from torch import tensor
+        >>> from torchmetrics.classification import BinaryNegativePredictiveValue
+        >>> target = tensor([0, 1, 0, 1, 0, 1])
+        >>> preds = tensor([0, 0, 1, 1, 0, 1])
+        >>> metric = BinaryNegativePredictiveValue()
+        >>> metric(preds, target)
+        tensor(0.6667)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.classification import BinaryNegativePredictiveValue
+        >>> target = tensor([0, 1, 0, 1, 0, 1])
+        >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
+        >>> metric = BinaryNegativePredictiveValue()
+        >>> metric(preds, target)
+        tensor(0.6667)
+
+    Example (multidim tensors):
+        >>> from torchmetrics.classification import BinaryNegativePredictiveValue
+        >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
+        >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
+        ...
[[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]) + >>> metric = BinaryNegativePredictiveValue(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.0000, 0.2500]) + + """ + + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def compute(self) -> Tensor: + """Compute metric.""" + tp, fp, tn, fn = self._final_state() + return _negative_predictive_value_reduce( + tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average + ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting a single value + >>> from torchmetrics.classification import BinaryNegativePredictiveValue + >>> metric = BinaryNegativePredictiveValue() + >>> metric.update(rand(10), randint(2,(10,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import rand, randint + >>> # Example plotting multiple values + >>> from torchmetrics.classification import BinaryNegativePredictiveValue + >>> metric = BinaryNegativePredictiveValue() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(rand(10), randint(2,(10,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + + +class MulticlassNegativePredictiveValue(MulticlassStatScores): + r"""Compute `Negative Predictive Value`_ for multiclass tasks. + + .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is + encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be + affected in turn. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``. + If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert + probabilities/logits into an int tensor. 
+ - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``npv`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average`` + arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + + Args: + num_classes: Integer specifying the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.classification import MulticlassNegativePredictiveValue + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([2, 1, 0, 1]) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3) + >>> metric(preds, target) + tensor(0.8889) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.6667, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MulticlassNegativePredictiveValue + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... 
[0.05, 0.82, 0.13]]) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3) + >>> metric(preds, target) + tensor(0.8889) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.6667, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MulticlassNegativePredictiveValue + >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.7833, 0.6556]) + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[1.0000, 0.6000, 0.7500], + [0.8000, 0.5000, 0.6667]]) + + """ + + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Class" + + def compute(self) -> Tensor: + """Compute metric.""" + tp, fp, tn, fn = self._final_state() + return _negative_predictive_value_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k + ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure object and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> # Example plotting a single value per class + >>> from torchmetrics.classification import MulticlassNegativePredictiveValue + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None) + >>> metric.update(randint(3, (20,)), randint(3, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randint + >>> # Example plotting a multiple values per class + >>> from torchmetrics.classification import MulticlassNegativePredictiveValue + >>> metric = MulticlassNegativePredictiveValue(num_classes=3, average=None) + >>> values = [] + >>> for _ in range(20): + ... values.append(metric(randint(3, (20,)), randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + + +class MultilabelNegativePredictiveValue(MultilabelStatScores): + r"""Compute `Negative Predictive Value`_ for multilabel tasks. + + .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is + encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be + affected in turn. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, C, ...)``. If preds is a floating + point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid + per element. 
Additionally, we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``npv`` (:class:`~torch.Tensor`): The returned shape depends on the ``average`` and ``multidim_average`` + arguments: + + - If ``multidim_average`` is set to ``global`` + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise`` + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + If ``multidim_average`` is set to ``samplewise`` we expect at least one additional dimension ``...`` to be present, + which the reduction will then be applied over instead of the sample dimension ``N``. + + Args: + num_labels: Integer specifying the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: calculates statistic for each label and applies no reduction + + multidim_average: Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.classification import MultilabelNegativePredictiveValue + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelNegativePredictiveValue(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + >>> mls = MultilabelNegativePredictiveValue(num_labels=3, average=None) + >>> mls(preds, target) + tensor([1.0000, 0.5000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.classification import MultilabelNegativePredictiveValue + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelNegativePredictiveValue(num_labels=3) + >>> metric(preds, target) + tensor(0.5000) + >>> mls = MultilabelNegativePredictiveValue(num_labels=3, average=None) + >>> mls(preds, target) + tensor([1.0000, 0.5000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics.classification import MultilabelNegativePredictiveValue + >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... 
[[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]])
+        >>> metric = MultilabelNegativePredictiveValue(num_labels=3, multidim_average='samplewise')
+        >>> metric(preds, target)
+        tensor([0.0000, 0.1667])
+        >>> mls = MultilabelNegativePredictiveValue(num_labels=3, multidim_average='samplewise', average=None)
+        >>> mls(preds, target)
+        tensor([[0.0000, 0.0000, 0.0000],
+                [0.0000, 0.0000, 0.5000]])
+
+    """
+
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+    plot_legend_name: str = "Label"
+
+    def compute(self) -> Tensor:
+        """Compute metric."""
+        tp, fp, tn, fn = self._final_state()
+        return _negative_predictive_value_reduce(
+            tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
+        )
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling ``metric.forward`` or ``metric.compute`` or a list of these
+                results. If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis.
+
+        Returns:
+            Figure object and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import rand, randint
+            >>> # Example plotting a single value
+            >>> from torchmetrics.classification import MultilabelNegativePredictiveValue
+            >>> metric = MultilabelNegativePredictiveValue(num_labels=3)
+            >>> metric.update(randint(2, (20, 3)), randint(2, (20, 3)))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> from torch import rand, randint
+            >>> # Example plotting multiple values
+            >>> from torchmetrics.classification import MultilabelNegativePredictiveValue
+            >>> metric = MultilabelNegativePredictiveValue(num_labels=3)
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(randint(2, (20, 3)), randint(2, (20, 3))))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
+
+
+class NegativePredictiveValue(_ClassificationTaskWrapper):
+    r"""Compute `Negative Predictive Value`_.
+
+    .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FN}}
+
+    Where :math:`\text{TN}` and :math:`\text{FN}` represent the number of true negatives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TN} + \text{FN} \neq 0`. If this case is
+    encountered for any class/label, the metric for that class/label will be set to 0 and the overall metric may
+    therefore be affected in turn.
+
+    This class is a simple wrapper to get the task-specific versions of this metric, which is done by setting the
+    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
+    :class:`~torchmetrics.classification.BinaryNegativePredictiveValue`,
+    :class:`~torchmetrics.classification.MulticlassNegativePredictiveValue`
+    and :class:`~torchmetrics.classification.MultilabelNegativePredictiveValue` for the specific details of each
+    argument's influence and examples.
+ + Legacy Example: + >>> from torch import tensor + >>> preds = tensor([2, 0, 2, 1]) + >>> target = tensor([1, 1, 2, 0]) + >>> nvp = NegativePredictiveValue(task="multiclass", average='macro', num_classes=3) + >>> nvp(preds, target) + tensor(0.6667) + >>> nvp = NegativePredictiveValue(task="multiclass", average='micro', num_classes=3) + >>> nvp(preds, target) + tensor(0.6250) + + """ + + def __new__( # type: ignore[misc] + cls: Type["NegativePredictiveValue"], + task: Literal["binary", "multiclass", "multilabel"], + threshold: float = 0.5, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = 1, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + task = ClassificationTask.from_str(task) + assert multidim_average is not None # noqa: S101 # needed for mypy + kwargs.update({ + "multidim_average": multidim_average, + "ignore_index": ignore_index, + "validate_args": validate_args, + }) + if task == ClassificationTask.BINARY: + return BinaryNegativePredictiveValue(threshold, **kwargs) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + if not isinstance(top_k, int): + raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)} was passed.`") + return MulticlassNegativePredictiveValue(num_classes, top_k, average, **kwargs) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return MultilabelNegativePredictiveValue(num_labels, threshold, average, **kwargs) + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 7de7f261867..f76d907c907 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -36,6 +36,7 @@ hinge_loss, jaccard_index, matthews_corrcoef, + negative_predictive_value, precision, precision_at_fixed_recall, precision_recall_curve, @@ -177,6 +178,7 @@ "mean_squared_log_error", "minkowski_distance", "multiscale_structural_similarity_index_measure", + "negative_predictive_value", "normalized_root_mean_squared_error", "pairwise_cosine_similarity", "pairwise_euclidean_distance", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index faf523844bc..73ff9fcc1ea 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -77,6 +77,12 @@ multiclass_matthews_corrcoef, multilabel_matthews_corrcoef, ) +from torchmetrics.functional.classification.negative_predictive_value import ( + binary_negative_predictive_value, + multiclass_negative_predictive_value, + multilabel_negative_predictive_value, + negative_predictive_value, +) from torchmetrics.functional.classification.precision_fixed_recall import ( binary_precision_at_fixed_recall, multiclass_precision_at_fixed_recall, @@ -234,4 +240,8 @@ "demographic_parity", "equal_opportunity", "precision_at_fixed_recall", + "binary_negative_predictive_value", + "multiclass_negative_predictive_value", + "multilabel_negative_predictive_value", 
+ "negative_predictive_value", ] diff --git a/src/torchmetrics/functional/classification/negative_predictive_value.py b/src/torchmetrics/functional/classification/negative_predictive_value.py new file mode 100644 index 00000000000..65297e0c7e6 --- /dev/null +++ b/src/torchmetrics/functional/classification/negative_predictive_value.py @@ -0,0 +1,419 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, +) +from torchmetrics.utilities.compute import _adjust_weights_safe_divide, _safe_divide +from torchmetrics.utilities.enums import ClassificationTask + + +def _negative_predictive_value_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", + multilabel: bool = False, + top_k: int = 1, + zero_division: float = 0, +) -> Tensor: + """Reduction logic for negative predictive value.""" + if average == "binary": + return _safe_divide(tn, tn + fn, zero_division) + if average == "micro": + tn = tn.sum(dim=0 if multidim_average == "global" else 1) + fn = fn.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tn, tn + fn, zero_division) + score = _safe_divide(tn, tn + fn, zero_division) + return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k) + + +def binary_negative_predictive_value( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + zero_division: float = 0, +) -> Tensor: + r"""Compute `Negative Predictive Value`_ for binary tasks. + + .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is + encountered a score of 0 is returned. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, + we convert to int tensor with thresholding using the value in ``threshold``. 
+ - ``target`` (int tensor): ``(N, ...)`` + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.functional.classification import binary_negative_predictive_value + >>> target = tensor([0, 1, 0, 1, 0, 1]) + >>> preds = tensor([0, 0, 1, 1, 0, 1]) + >>> binary_negative_predictive_value(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import binary_negative_predictive_value + >>> target = tensor([0, 1, 0, 1, 0, 1]) + >>> preds = tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_negative_predictive_value(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import binary_negative_predictive_value + >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]) + >>> binary_negative_predictive_value(preds, target, multidim_average='samplewise') + tensor([0.0000, 0.2500]) + + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _negative_predictive_value_reduce( + tp, fp, tn, fn, average="binary", multidim_average=multidim_average, zero_division=zero_division + ) + + +def multiclass_negative_predictive_value( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + zero_division: float = 0, +) -> Tensor: + r"""Compute `Negative Predictive Value`_ for multiclass tasks. + + .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is + encountered a score of 0 is returned. 
+ + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifying the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculate statistics for each label and compute a weighted average using their support + - ``"none"`` or ``None``: Calculate statistics for each label and apply no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: Bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multiclass_negative_predictive_value + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([2, 1, 0, 1]) + >>> multiclass_negative_predictive_value(preds, target, num_classes=3) + tensor(0.8889) + >>> multiclass_negative_predictive_value(preds, target, num_classes=3, average=None) + tensor([0.6667, 1.0000, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multiclass_negative_predictive_value + >>> target = tensor([2, 1, 0, 0]) + >>> preds = tensor([[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... 
[0.05, 0.82, 0.13]]) + >>> multiclass_negative_predictive_value(preds, target, num_classes=3) + tensor(0.8889) + >>> multiclass_negative_predictive_value(preds, target, num_classes=3, average=None) + tensor([0.6667, 1.0000, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multiclass_negative_predictive_value + >>> target = tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_negative_predictive_value(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.7833, 0.6556]) + >>> multiclass_negative_predictive_value( + ... preds, target, num_classes=3, multidim_average='samplewise', average=None + ... ) + tensor([[1.0000, 0.6000, 0.7500], + [0.8000, 0.5000, 0.6667]]) + + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, num_classes, top_k, average, multidim_average, ignore_index + ) + return _negative_predictive_value_reduce( + tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k, zero_division=zero_division + ) + + +def multilabel_negative_predictive_value( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + zero_division: float = 0, +) -> Tensor: + r"""Compute `Negative Predictive Value`_ for multilabel tasks. + + .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and false positives + respectively. The metric is only proper defined when :math:`\text{TN} + \text{FP} \neq 0`. If this case is + encountered a score of 0 is returned. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifying the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculate statistics for each label and compute a weighted average using their support + - ``"none"`` or ``None``: Calculate statistics for each label and apply no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. 
+ The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: Bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torch import tensor + >>> from torchmetrics.functional.classification import multilabel_negative_predictive_value + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_negative_predictive_value(preds, target, num_labels=3) + tensor(0.5000) + >>> multilabel_negative_predictive_value(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.5000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional.classification import multilabel_negative_predictive_value + >>> target = tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_negative_predictive_value(preds, target, num_labels=3) + tensor(0.5000) + >>> multilabel_negative_predictive_value(preds, target, num_labels=3, average=None) + tensor([1.0000, 0.5000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional.classification import multilabel_negative_predictive_value + >>> target = tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = tensor([[[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]]]) + >>> multilabel_negative_predictive_value(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.0000, 0.1667]) + >>> multilabel_negative_predictive_value( + ... preds, target, num_labels=3, multidim_average='samplewise', average=None + ... 
)
+    tensor([[0.0000, 0.0000, 0.0000],
+            [0.0000, 0.0000, 0.5000]])
+
+    """
+    if validate_args:
+        _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
+        _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index)
+    preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index)
+    tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average)
+    return _negative_predictive_value_reduce(
+        tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True, zero_division=zero_division
+    )
+
+
+def negative_predictive_value(
+    preds: Tensor,
+    target: Tensor,
+    task: Literal["binary", "multiclass", "multilabel"],
+    threshold: float = 0.5,
+    num_classes: Optional[int] = None,
+    num_labels: Optional[int] = None,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+    multidim_average: Optional[Literal["global", "samplewise"]] = "global",
+    top_k: Optional[int] = 1,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+    zero_division: float = 0,
+) -> Tensor:
+    r"""Compute `Negative Predictive Value`_.
+
+    .. math:: \text{Negative Predictive Value} = \frac{\text{TN}}{\text{TN} + \text{FN}}
+
+    Where :math:`\text{TN}` and :math:`\text{FN}` represent the number of true negatives and false negatives
+    respectively. The metric is only properly defined when :math:`\text{TN} + \text{FN} \neq 0`. If this case is
+    encountered, a score of 0 is returned.
+
+    This function is a simple wrapper to get the task-specific versions of this metric, which is done by setting the
+    ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of
+    :func:`~torchmetrics.functional.classification.binary_negative_predictive_value`,
+    :func:`~torchmetrics.functional.classification.multiclass_negative_predictive_value` and
+    :func:`~torchmetrics.functional.classification.multilabel_negative_predictive_value` for the specific
+    details of each argument's influence and examples.
+
+    Legacy Example:
+        >>> from torch import tensor
+        >>> preds = tensor([2, 0, 2, 1])
+        >>> target = tensor([1, 1, 2, 0])
+        >>> negative_predictive_value(preds, target, task="multiclass", average='macro', num_classes=3)
+        tensor(0.6667)
+        >>> negative_predictive_value(preds, target, task="multiclass", average='micro', num_classes=3)
+        tensor(0.6250)
+
+    """
+    task = ClassificationTask.from_str(task)
+    assert multidim_average is not None  # noqa: S101  # needed for mypy
+    if task == ClassificationTask.BINARY:
+        return binary_negative_predictive_value(
+            preds, target, threshold, multidim_average, ignore_index, validate_args, zero_division
+        )
+    if task == ClassificationTask.MULTICLASS:
+        if not isinstance(num_classes, int):
+            raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)}` was passed.")
+        if not isinstance(top_k, int):
+            raise ValueError(f"`top_k` is expected to be `int` but `{type(top_k)}` was passed.")
+        return multiclass_negative_predictive_value(
+            preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args, zero_division
+        )
+    if task == ClassificationTask.MULTILABEL:
+        if not isinstance(num_labels, int):
+            raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)}` was passed.")
+        return multilabel_negative_predictive_value(
+            preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args, zero_division
+        )
+    raise ValueError(f"Not handled value: {task}")
diff --git a/tests/unittests/classification/test_negative_predictive_value.py b/tests/unittests/classification/test_negative_predictive_value.py
new file mode 100644
index 00000000000..464884ca82a
--- /dev/null
+++ b/tests/unittests/classification/test_negative_predictive_value.py
@@ -0,0 +1,600 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
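+
+# The reference implementations in this module recover negative predictive value from raw
+# confusion-matrix counts as NPV = TN / (TN + FN), returning 0 when TN + FN == 0. A minimal
+# illustrative sketch of that calculation (the helper name is hypothetical, not part of the
+# tested API):
+#
+#     def npv_from_counts(tn: float, fn: float) -> float:
+#         denom = tn + fn
+#         return tn / denom if denom else 0.0  # guard the undefined TN + FN == 0 case
+#
+# e.g. npv_from_counts(2, 1) gives 2 / 3, i.e. 0.6667.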
+from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from torch import Tensor, tensor +from torchmetrics.classification.negative_predictive_value import ( + BinaryNegativePredictiveValue, + MulticlassNegativePredictiveValue, + MultilabelNegativePredictiveValue, + NegativePredictiveValue, +) +from torchmetrics.functional.classification.negative_predictive_value import ( + binary_negative_predictive_value, + multiclass_negative_predictive_value, + multilabel_negative_predictive_value, +) +from torchmetrics.metric import Metric + +from unittests import NUM_CLASSES, THRESHOLD +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases + +seed_all(42) + + +def _calc_negative_predictive_value(tn, fn): + """Safely calculate negative_predictive_value.""" + denom = tn + fn + if np.isscalar(tn): + denom = 1.0 if denom == 0 else denom + else: + denom[denom == 0] = 1.0 + return tn / denom + + +def _reference_negative_predictive_value_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + tn, _, fn, _ = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return _calc_negative_predictive_value(tn, fn) + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, _, fn, _ = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(_calc_negative_predictive_value(tn, fn)) + return np.stack(res) + + +@pytest.mark.parametrize("inputs", _binary_cases) +class TestBinaryNegativePredictiveValue(MetricTester): + """Test class for `BinaryNegativePredictiveValue` metric.""" + + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_binary_negative_predictive_value(self, ddp, inputs, ignore_index, multidim_average): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryNegativePredictiveValue, + reference_metric=partial( + _reference_negative_predictive_value_binary, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) + + 
@pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_negative_predictive_value_functional(self, inputs, ignore_index, multidim_average): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_negative_predictive_value, + reference_metric=partial( + _reference_negative_predictive_value_binary, + ignore_index=ignore_index, + multidim_average=multidim_average, + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) + + def test_binary_negative_predictive_value_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryNegativePredictiveValue, + metric_functional=binary_negative_predictive_value, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_negative_predictive_value_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryNegativePredictiveValue, + metric_functional=binary_negative_predictive_value, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_negative_predictive_value_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryNegativePredictiveValue, + metric_functional=binary_negative_predictive_value, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + +def _reference_negative_predictive_value_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + if average == "micro": + return _calc_negative_predictive_value(tn.sum(), fn.sum()) + + res = _calc_negative_predictive_value(tn, fn) + if average == "macro": + res = res[(np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0] + return res.mean(0) + if average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + if average is None or average == "none": + return res + return None + + +def _reference_negative_predictive_value_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, 
target): + pred = pred.flatten() + true = true.flatten() + + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + if average == "micro": + res.append(_calc_negative_predictive_value(tn.sum(), fn.sum())) + + r = _calc_negative_predictive_value(tn, fn) + if average == "macro": + r = r[(np.bincount(pred, minlength=NUM_CLASSES) + np.bincount(true, minlength=NUM_CLASSES)) != 0.0] + res.append(r.mean(0) if len(r) > 0 else 0.0) + elif average == "weighted": + w = tp + fn + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) + elif average is None or average == "none": + res.append(r) + return np.stack(res, 0) + + +def _reference_negative_predictive_value_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _reference_negative_predictive_value_multiclass_global(preds, target, ignore_index, average) + return _reference_negative_predictive_value_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("inputs", _multiclass_cases) +class TestMulticlassNegativePredictiveValue(MetricTester): + """Test class for `MulticlassNegativePredictiveValue` metric.""" + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_multiclass_negative_predictive_value(self, ddp, inputs, ignore_index, multidim_average, average): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassNegativePredictiveValue, + reference_metric=partial( + _reference_negative_predictive_value_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multiclass_negative_predictive_value_functional(self, inputs, ignore_index, multidim_average, average): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_negative_predictive_value, + reference_metric=partial( + _reference_negative_predictive_value_multiclass, + ignore_index=ignore_index, + 
multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + + def test_multiclass_negative_predictive_value_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassNegativePredictiveValue, + metric_functional=multiclass_negative_predictive_value, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_negative_predictive_value_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassNegativePredictiveValue, + metric_functional=multiclass_negative_predictive_value, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_negative_predictive_value_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassNegativePredictiveValue, + metric_functional=multiclass_negative_predictive_value, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + +_mc_k_target = tensor([0, 1, 2]) +_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) + + +@pytest.mark.parametrize( + ("k", "preds", "target", "average", "expected_spec"), + [ + (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)), + (2, _mc_k_preds, _mc_k_target, "micro", tensor(1)), + ], +) +def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassNegativePredictiveValue(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + assert torch.equal(class_metric.compute(), expected_spec) + assert torch.equal( + multiclass_negative_predictive_value(preds, target, top_k=k, average=average, num_classes=3), expected_spec + ) + + +def _reference_negative_predictive_value_multilabel_global(preds, target, ignore_index, average): + tns, fns = [], [] + for i in range(preds.shape[1]): + p, t = preds[:, i].flatten(), target[:, i].flatten() + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + tn, _, fn, _ = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() + tns.append(tn) + fns.append(fn) + + tn = np.array(tns) + fn = np.array(fns) + if average == "micro": + return _calc_negative_predictive_value(tn.sum(), fn.sum()) + + res = _calc_negative_predictive_value(tn, fn) + if average == "macro": + return res.mean(0) + if average == "weighted": + w = res[:, 0] + res[:, 3] + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + if average is None or average == "none": + return res + return None + + +def _reference_negative_predictive_value_multilabel_local(preds, target, ignore_index, average): + negative_predictive_value = [] + for i in range(preds.shape[0]): + tns, fns 
= [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, _, fn, _ = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + tns.append(tn) + fns.append(fn) + tn = np.array(tns) + fn = np.array(fns) + if average == "micro": + negative_predictive_value.append(_calc_negative_predictive_value(tn.sum(), fn.sum())) + else: + negative_predictive_value.append(_calc_negative_predictive_value(tn, fn)) + + res = np.stack(negative_predictive_value, 0) + if average == "micro" or average is None or average == "none": + return res + if average == "macro": + return res.mean(-1) + if average == "weighted": + w = res[:, 0, :] + res[:, 3, :] + return (res * (w / w.sum())[:, np.newaxis]).sum(-1) + if average is None or average == "none": + return np.moveaxis(res, 1, -1) + return None + + +def _reference_negative_predictive_value_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if multidim_average == "global": + return _reference_negative_predictive_value_multilabel_global(preds, target, ignore_index, average) + return _reference_negative_predictive_value_multilabel_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("inputs", _multilabel_cases) +class TestMultilabelNegativePredictiveValue(MetricTester): + """Test class for `MultilabelNegativePredictiveValue` metric.""" + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_negative_predictive_value(self, ddp, inputs, ignore_index, multidim_average, average): + """Test class implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelNegativePredictiveValue, + reference_metric=partial( + _reference_negative_predictive_value_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_negative_predictive_value_functional(self, inputs, ignore_index, multidim_average, average): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and 
non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_negative_predictive_value, + reference_metric=partial( + _reference_negative_predictive_value_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + def test_multilabel_negative_predictive_value_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelNegativePredictiveValue, + metric_functional=multilabel_negative_predictive_value, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_negative_predictive_value_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelNegativePredictiveValue, + metric_functional=multilabel_negative_predictive_value, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_negative_predictive_value_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelNegativePredictiveValue, + metric_functional=multilabel_negative_predictive_value, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + +def test_corner_cases(): + """Test corner cases for negative predictive value metric.""" + # simulate the output of a perfect predictor (i.e. preds == target) + target = torch.tensor([0, 1, 2, 0, 1, 2]) + preds = target + + metric = MulticlassNegativePredictiveValue(num_classes=3, average="none", ignore_index=0) + res = metric(preds, target) + assert torch.allclose(res, torch.tensor([1.0, 1.0, 1.0])) + + metric = MulticlassNegativePredictiveValue(num_classes=3, average="macro", ignore_index=0) + res = metric(preds, target) + assert res == 1.0 + + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryNegativePredictiveValue, {"task": "binary"}), + (MulticlassNegativePredictiveValue, {"task": "multiclass", "num_classes": 3}), + (MultilabelNegativePredictiveValue, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=NegativePredictiveValue): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric)
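
The reference helpers above all reduce to the same quantity: negative predictive value is TN / (TN + FN), computed per class (or per label) from a confusion matrix before averaging. The standalone sketch below — not part of the PR — mirrors that reduction with sklearn and NumPy; the zero-division guard (returning 0 when TN + FN == 0) is an assumption of the sketch, since the `_calc_negative_predictive_value` helper it stands in for is defined earlier in the test module and not shown in this hunk.

    import numpy as np
    from sklearn.metrics import confusion_matrix

    def npv_per_class(y_true, y_pred, num_classes):
        # per-class TP/FP/FN/TN from a multiclass confusion matrix, as in the
        # _reference_negative_predictive_value_multiclass_global helper above
        confmat = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))
        tp = np.diag(confmat)
        fp = confmat.sum(0) - tp
        fn = confmat.sum(1) - tp
        tn = confmat.sum() - (tp + fp + fn)
        denom = tn + fn
        # NPV = TN / (TN + FN); the zero-denominator guard is an assumption here
        return np.divide(tn, denom, out=np.zeros(num_classes), where=denom != 0)

    # a perfect 3-class prediction yields NPV == 1.0 for every class
    print(npv_per_class([0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2], 3))  # [1. 1. 1.]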
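
To make the `test_top_k` expectation of 5/6 easier to audit: with `k=1` the argmax predictions for `_mc_k_preds` are `[1, 1, 2]` against targets `[0, 1, 2]`, giving per-class counts TN=2/FN=1 (class 0), TN=1/FN=0 (class 1) and TN=2/FN=0 (class 2), so the micro-averaged NPV is 5 / (5 + 1) = 5/6. A minimal check of the same numbers against the module interface:

    import torch
    from torchmetrics.classification import MulticlassNegativePredictiveValue

    preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
    target = torch.tensor([0, 1, 2])

    # micro averaging pools TN and FN over all classes before dividing
    metric = MulticlassNegativePredictiveValue(num_classes=3, top_k=1, average="micro")
    assert torch.isclose(metric(preds, target), torch.tensor(5 / 6))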
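
Finally, a usage sketch of the dispatch that `test_wrapper_class` asserts: the task-agnostic `NegativePredictiveValue` wrapper resolves to the task-specific metric through its `task` argument and raises a `ValueError` for an unknown task.

    from torchmetrics.classification import BinaryNegativePredictiveValue, NegativePredictiveValue

    metric = NegativePredictiveValue(task="binary")
    assert isinstance(metric, BinaryNegativePredictiveValue)

    # an unsupported task string is rejected, matching the ValueError branch in the test above
    try:
        NegativePredictiveValue(task="not_valid_task")
    except ValueError:
        pass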