Add ClassificationTask Enum (#1479)
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored Feb 7, 2023
1 parent 82f2e45 commit b95d482
Showing 43 changed files with 310 additions and 297 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419))
 
 
+- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))
+
+
 ### Changed
 
 - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370))
@@ -31,6 +34,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Raise exception for invalid kwargs in Metric base class ([#1427](https://github.com/Lightning-AI/metrics/pull/1427))
 
 
+- Extend `EnumStr` raising `ValueError` for invalid value ([#1479](https://github.com/Lightning-AI/metrics/pull/1479))
+
+
 ### Deprecated
 
 -
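
In practice the two entries work together: `ClassificationTask` replaces raw task strings inside the wrapper classes below, and the extended `EnumStr` makes the string-to-enum conversion fail loudly. A minimal sketch of the intended behavior, assuming the enum members shown in the diffs below and the `from_str` helper from lightning-utilities >= 0.5.0 (hence the requirements bump that follows):

    from torchmetrics.utilities.enums import ClassificationTask

    # from_str normalizes a user-supplied string to an enum member ...
    task = ClassificationTask.from_str("binary")
    assert task is ClassificationTask.BINARY

    # ... and, per the second entry, an unknown value now raises ValueError
    # instead of silently slipping through to a generic fallback.
    try:
        ClassificationTask.from_str("not-a-task")
    except ValueError as err:
        print(err)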
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,4 +2,4 @@ numpy>=1.17.2
 torch>=1.8.1
 typing-extensions; python_version < '3.9'
 packaging # hotfix for utils, can be dropped with lit-utils >=0.5
-lightning-utilities>=0.4.1
+lightning-utilities>=0.5.0
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/accuracy.py
@@ -18,6 +18,7 @@
 
 from torchmetrics.functional.classification.accuracy import _accuracy_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
 
@@ -490,18 +491,16 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAccuracy(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
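
The wrapper's public API is unchanged; only the dispatch inside `__new__` now goes through the enum, and the trailing `raise ValueError` is gone because `from_str` rejects bad input up front. A usage sketch (the tensors and `num_classes=3` are illustrative, not part of this commit):

    import torch
    from torchmetrics import Accuracy

    # task is still passed as a plain string; __new__ converts it with
    # ClassificationTask.from_str and returns the matching subclass.
    acc = Accuracy(task="multiclass", num_classes=3)
    print(type(acc).__name__)  # MulticlassAccuracy

    preds = torch.tensor([0, 2, 1, 2])
    target = torch.tensor([0, 1, 1, 2])
    print(acc(preds, target))  # tensor(0.8333), default macro average

    # A typo now fails inside from_str, before any branch is reached:
    # Accuracy(task="multi-class")  # raises ValueError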
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/auroc.py
@@ -31,6 +31,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -352,15 +353,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAUROC(max_fpr, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassAUROC(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAUROC(num_labels, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/average_precision.py
@@ -30,6 +30,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
@@ -356,15 +357,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAveragePrecision(**kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassAveragePrecision(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAveragePrecision(num_labels, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
9 changes: 4 additions & 5 deletions src/torchmetrics/classification/calibration_error.py
@@ -29,6 +29,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
 
 
 class BinaryCalibrationError(Metric):
@@ -267,12 +268,10 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTaskNoMultilabel.BINARY:
             return BinaryCalibrationError(**kwargs)
-        if task == "multiclass":
+        if task == ClassificationTaskNoMultilabel.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassCalibrationError(num_classes, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
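
Metrics with no multilabel variant now validate through `ClassificationTaskNoMultilabel`, so an unsupported task is rejected by `from_str` itself; note the deleted fallback message still advertised `'multilabel'` for a metric that never supported it, an inconsistency the task-specific enum removes. A sketch of the resulting behavior (parameter values are illustrative):

    from torchmetrics import CalibrationError

    ce = CalibrationError(task="binary", n_bins=10)  # -> BinaryCalibrationError

    # "multilabel" is not a member of ClassificationTaskNoMultilabel, so
    # from_str raises ValueError before the if-chain runs.
    try:
        CalibrationError(task="multilabel")
    except ValueError as err:
        print(err)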
9 changes: 4 additions & 5 deletions src/torchmetrics/classification/cohen_kappa.py
@@ -23,6 +23,7 @@
     _multiclass_cohen_kappa_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
 
 
 class BinaryCohenKappa(BinaryConfusionMatrix):
@@ -221,12 +222,10 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTaskNoMultilabel.BINARY:
             return BinaryCohenKappa(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTaskNoMultilabel.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassCohenKappa(num_classes, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/confusion_matrix.py
@@ -35,6 +35,7 @@
     _multilabel_confusion_matrix_update,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix
 
@@ -397,15 +398,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryConfusionMatrix(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassConfusionMatrix(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/exact_match.py
@@ -32,6 +32,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTaskNoBinary
 
 
 class MulticlassExactMatch(Metric):
@@ -288,13 +289,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTaskNoBinary.from_str(task)
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "multiclass":
+        if task == ClassificationTaskNoBinary.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassExactMatch(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTaskNoBinary.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelExactMatch(num_labels, threshold, **kwargs)
-        raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
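
`ExactMatch` is the mirror case: it has no binary variant, so it validates through `ClassificationTaskNoBinary`. A sketch under the same assumptions (`num_labels=4` is illustrative):

    from torchmetrics import ExactMatch

    em = ExactMatch(task="multilabel", num_labels=4)  # -> MultilabelExactMatch

    # "binary" is not a member of ClassificationTaskNoBinary, so from_str
    # raises ValueError, which is why the one-line trailing raise could go.
    try:
        ExactMatch(task="binary")
    except ValueError as err:
        print(err)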
17 changes: 8 additions & 9 deletions src/torchmetrics/classification/f_beta.py
@@ -24,6 +24,7 @@
     _multilabel_fbeta_score_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryFBetaScore(BinaryStatScores):
@@ -729,13 +730,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryFBetaScore(beta, threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
         raise ValueError(
@@ -776,19 +777,17 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryF1Score(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassF1Score(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelF1Score(num_labels, threshold, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
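
Note the asymmetry above: `F1Score.__new__` converts `task` with `from_str`, while the visible part of `FBetaScore.__new__` compares what may still be a plain string against enum members. Both comparisons hold if, as in lightning-utilities, `EnumStr` mixes in `str`, so a member compares equal to its string value; the sketch below spells out that assumption:

    from torchmetrics.utilities.enums import ClassificationTask

    # Assumption: EnumStr subclasses str and the member values are the
    # lowercase task names, so plain strings compare equal to members.
    assert "binary" == ClassificationTask.BINARY
    assert ClassificationTask.from_str("binary") == ClassificationTask.BINARY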
12 changes: 5 additions & 7 deletions src/torchmetrics/classification/hamming.py
@@ -19,6 +19,7 @@
 from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
 from torchmetrics.functional.classification.hamming import _hamming_distance_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryHammingDistance(BinaryStatScores):
@@ -344,20 +345,17 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
-
+        task = ClassificationTask.from_str(task)
         assert multidim_average is not None
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryHammingDistance(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassHammingDistance(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelHammingDistance(num_labels, threshold, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
9 changes: 4 additions & 5 deletions src/torchmetrics/classification/hinge.py
@@ -29,6 +29,7 @@
     _multiclass_hinge_loss_update,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel
 
 
 class BinaryHingeLoss(Metric):
@@ -252,12 +253,10 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTaskNoMultilabel.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTaskNoMultilabel.BINARY:
             return BinaryHingeLoss(squared, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTaskNoMultilabel.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/jaccard.py
@@ -23,6 +23,7 @@
     _multilabel_jaccard_index_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryJaccardIndex(BinaryConfusionMatrix):
@@ -295,15 +296,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryJaccardIndex(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassJaccardIndex(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
11 changes: 5 additions & 6 deletions src/torchmetrics/classification/matthews_corrcoef.py
@@ -19,6 +19,7 @@
 from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix
 from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 
 
 class BinaryMatthewsCorrCoef(BinaryConfusionMatrix):
@@ -237,15 +238,13 @@ def __new__(
         validate_args: bool = True,
         **kwargs: Any,
     ) -> Metric:
+        task = ClassificationTask.from_str(task)
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryMatthewsCorrCoef(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassMatthewsCorrCoef(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
-        raise ValueError(
-            f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}"
-        )
