Add zero_division option to the precision, recall, f1, fbeta. (#2198)
* Add support of zero_division parameter

* fix overlooked

* Fix type error

* Fix type error

* Fix missing comma

* Doc fix wrong math expression

* Fixed StatScores to have zero_division

* fix missing zero_division arg

* fix device mismatch

* use scikit-learn 1.4.0

* fix scikit-learn min ver

* fix for new sklearn version

* fix scikit-learn requirements

* fix incorrect requirements condition

* fix test code to pass in multiple sklearn versions

* changelog

* better docstring

* add jaccardindex

* fix tests

* skip for old sklearn versions

---------

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
5 people authored May 3, 2024
1 parent d9add3d commit 335ebe6
Showing 17 changed files with 606 additions and 158 deletions.
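
In practice the new argument controls what an undefined score collapses to. A minimal usage sketch, assuming a torchmetrics build that already includes this commit:

```python
import torch
from torchmetrics.classification import BinaryPrecision

# No positive predictions: TP + FP == 0, so precision is undefined.
preds = torch.tensor([0, 0, 0, 0])
target = torch.tensor([0, 1, 0, 1])

# Default (and previous) behaviour: the undefined score is reported as 0.
print(BinaryPrecision()(preds, target))                 # tensor(0.)

# With the new argument the undefined score is reported as 1 instead.
print(BinaryPrecision(zero_division=1)(preds, target))  # tensor(1.)
```

The same argument is threaded through recall, F-beta/F1, the stat-scores base classes, and the Jaccard index, as the file-by-file diff below shows.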
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for `torch.float` weighted networks for FID and KID calculations ([#2483](https://github.com/Lightning-AI/torchmetrics/pull/2483))


- Added `zero_division` argument to selected classification metrics ([#2198](https://github.com/Lightning-AI/torchmetrics/pull/2198))


### Changed

- Made `__getattr__` and `__setattr__` of `ClasswiseWrapper` more general ([#2424](https://github.com/Lightning-AI/torchmetrics/pull/2424))
3 changes: 2 additions & 1 deletion requirements/_tests.txt
@@ -15,5 +15,6 @@ pyGithub ==2.3.0
fire <=0.6.0

cloudpickle >1.3, <=3.0.0
scikit-learn >=1.1.1, <1.4.0
scikit-learn >=1.1.1, <1.3.0; python_version < "3.9"
scikit-learn >=1.4.0, <1.5.0; python_version >= "3.9"
cachier ==3.0.0
90 changes: 76 additions & 14 deletions src/torchmetrics/classification/f_beta.py

Large diffs are not rendered by default.
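
The f_beta.py diff is collapsed here, but per the commit title the same zero_division argument is added to the F-beta/F1 metrics. A hedged usage sketch (the expected values assume the same zero_division convention as scikit-learn, which the updated test requirements below point to):

```python
import torch
from torchmetrics.classification import BinaryF1Score

# Neither preds nor target contain the positive class, so TP, FP and FN
# are all zero and the F1 score is undefined.
preds = torch.tensor([0, 0, 0])
target = torch.tensor([0, 0, 0])

print(BinaryF1Score()(preds, target))                 # expected: tensor(0.)
print(BinaryF1Score(zero_division=1)(preds, target))  # expected: tensor(1.)
```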

20 changes: 17 additions & 3 deletions src/torchmetrics/classification/jaccard.py
@@ -65,6 +65,8 @@ class BinaryJaccardIndex(BinaryConfusionMatrix):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example (preds is int tensor):
@@ -97,15 +99,17 @@ def __init__(
threshold: float = 0.5,
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs
)
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average="binary")
return _jaccard_index_reduce(self.confmat, average="binary", zero_division=self.zero_division)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -187,6 +191,8 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix):
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example (pred is integer tensor):
@@ -224,6 +230,7 @@ def __init__(
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
@@ -233,10 +240,13 @@ def __init__(
_multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average)
self.validate_args = validate_args
self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average, ignore_index=self.ignore_index)
return _jaccard_index_reduce(
self.confmat, average=self.average, ignore_index=self.ignore_index, zero_division=self.zero_division
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -319,6 +329,8 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix):
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division:
Value to replace when there is a division by zero. Should be `0` or `1`.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example (preds is int tensor):
@@ -354,6 +366,7 @@ def __init__(
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
ignore_index: Optional[int] = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: Any,
) -> None:
super().__init__(
@@ -368,10 +381,11 @@ def __init__(
_multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average)
self.validate_args = validate_args
self.average = average
self.zero_division = zero_division

def compute(self) -> Tensor:
"""Compute metric."""
return _jaccard_index_reduce(self.confmat, average=self.average)
return _jaccard_index_reduce(self.confmat, average=self.average, zero_division=self.zero_division)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
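
Taken together, the changes above thread zero_division through all three Jaccard task classes down to `_jaccard_index_reduce`. A short sketch of the binary case, assuming the reduce function handles the empty case as the new docstring describes:

```python
import torch
from torchmetrics.classification import BinaryJaccardIndex

# preds and target are all-negative, so the positive-class intersection
# and union are both empty and the Jaccard index is undefined.
preds = torch.tensor([0, 0, 0, 0])
target = torch.tensor([0, 0, 0, 0])

print(BinaryJaccardIndex()(preds, target))                 # tensor(0.) by default
print(BinaryJaccardIndex(zero_division=1)(preds, target))  # tensor(1.)
```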
88 changes: 71 additions & 17 deletions src/torchmetrics/classification/precision_recall.py
@@ -18,7 +18,9 @@

from torchmetrics.classification.base import _ClassificationTaskWrapper
from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
from torchmetrics.functional.classification.precision_recall import (
_precision_recall_reduce,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
@@ -42,7 +44,7 @@ class BinaryPrecision(BinaryStatScores):
Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered a score of 0 is returned.
encountered a score of `zero_division` (0 or 1, default is 0) is returned.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -73,6 +75,7 @@ class BinaryPrecision(BinaryStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -112,7 +115,14 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
"precision",
tp,
fp,
tn,
fn,
average="binary",
multidim_average=self.multidim_average,
zero_division=self.zero_division,
)

def plot(
@@ -165,8 +175,8 @@ class MulticlassPrecision(MulticlassStatScores):
Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -217,6 +227,7 @@ class MulticlassPrecision(MulticlassStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -269,7 +280,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
"precision",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
)

def plot(
@@ -322,8 +341,8 @@ class MultilabelPrecision(MultilabelStatScores):
Where :math:`\text{TP}` and :math:`\text{FP}` represent the number of true positives and false positives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FP} \neq 0`. If this case is
encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -373,6 +392,7 @@ class MultilabelPrecision(MultilabelStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FP} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -423,7 +443,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
"precision",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
multilabel=True,
zero_division=self.zero_division,
)

def plot(
@@ -476,7 +504,7 @@ class BinaryRecall(BinaryStatScores):
Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered a score of 0 is returned.
encountered a score of `zero_division` (0 or 1, default is 0) is returned.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -507,6 +535,7 @@ class BinaryRecall(BinaryStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -546,7 +575,14 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average
"recall",
tp,
fp,
tn,
fn,
average="binary",
multidim_average=self.multidim_average,
zero_division=self.zero_division,
)

def plot(
@@ -599,8 +635,8 @@ class MulticlassRecall(MulticlassStatScores):
Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered for any class, the metric for that class will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any class, the metric for that class will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -650,6 +686,7 @@ class MulticlassRecall(MulticlassStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -702,7 +739,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
"recall",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
)

def plot(
@@ -755,8 +800,8 @@ class MultilabelRecall(MultilabelStatScores):
Where :math:`\text{TP}` and :math:`\text{FN}` represent the number of true positives and false negatives
respectively. The metric is only properly defined when :math:`\text{TP} + \text{FN} \neq 0`. If this case is
encountered for any label, the metric for that label will be set to 0 and the overall metric may therefore be
affected in turn.
encountered for any label, the metric for that label will be set to `zero_division` (0 or 1, default is 0) and
the overall metric may therefore be affected in turn.
As input to ``forward`` and ``update`` the metric accepts the following input:
@@ -805,6 +850,7 @@ class MultilabelRecall(MultilabelStatScores):
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
zero_division: Should be `0` or `1`. The value returned when :math:`\text{TP} + \text{FN} = 0`.
Example (preds is int tensor):
>>> from torch import tensor
@@ -855,7 +901,15 @@ def compute(self) -> Tensor:
"""Compute metric."""
tp, fp, tn, fn = self._final_state()
return _precision_recall_reduce(
"recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True
"recall",
tp,
fp,
tn,
fn,
average=self.average,
multidim_average=self.multidim_average,
multilabel=True,
zero_division=self.zero_division,
)

def plot(
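
The per-class wording repeated in the multiclass and multilabel docstrings above is easiest to see with average="none". A sketch whose expected values follow from that wording (they are not copied from the test suite):

```python
import torch
from torchmetrics.classification import MulticlassRecall

# Class 2 is predicted once but never appears in the target,
# so its recall has a zero denominator (TP + FN == 0).
preds = torch.tensor([0, 1, 2, 0])
target = torch.tensor([0, 1, 1, 0])

print(MulticlassRecall(num_classes=3, average="none")(preds, target))
# -> roughly [1.0, 0.5, 0.0] with the default zero_division=0
print(MulticlassRecall(num_classes=3, average="none", zero_division=1)(preds, target))
# -> roughly [1.0, 0.5, 1.0]; only the undefined class changes
```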
16 changes: 13 additions & 3 deletions src/torchmetrics/classification/stat_scores.py
@@ -169,13 +169,15 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
_binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
self.threshold = threshold
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(size=1, multidim_average=multidim_average)

@@ -313,15 +315,19 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
_multiclass_stat_scores_arg_validation(
num_classes, top_k, average, multidim_average, ignore_index, zero_division
)
self.num_classes = num_classes
self.top_k = top_k
self.average = average
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(
size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average
@@ -461,15 +467,19 @@ def __init__(
validate_args: bool = True,
**kwargs: Any,
) -> None:
zero_division = kwargs.pop("zero_division", 0)
super(_AbstractStatScores, self).__init__(**kwargs)
if validate_args:
_multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index)
_multilabel_stat_scores_arg_validation(
num_labels, threshold, average, multidim_average, ignore_index, zero_division
)
self.num_labels = num_labels
self.threshold = threshold
self.average = average
self.multidim_average = multidim_average
self.ignore_index = ignore_index
self.validate_args = validate_args
self.zero_division = zero_division

self._create_state(size=num_labels, multidim_average=multidim_average)
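
Note the shared pattern in all three constructors above: zero_division arrives through **kwargs and is popped before the parent constructor runs, so the base Metric class never sees the extra key. A minimal standalone sketch of that pattern (the class names here are hypothetical, not the library's):

```python
from typing import Any


class _BaseMetric:
    """Stand-in for the parent class, which should not receive unknown kwargs."""

    def __init__(self, **kwargs: Any) -> None:
        self._extra = kwargs


class _StatScoresLike(_BaseMetric):
    """Pops ``zero_division`` out of ``**kwargs`` before delegating upward."""

    def __init__(self, threshold: float = 0.5, **kwargs: Any) -> None:
        zero_division = kwargs.pop("zero_division", 0)  # default of 0, matching the diff
        super().__init__(**kwargs)
        self.threshold = threshold
        self.zero_division = zero_division


metric = _StatScoresLike(zero_division=1)
assert metric.zero_division == 1  # popped and stored, never forwarded to the base class
```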
