From b7a8dc77475a5a29fc9a5a8a9a5acc50d99f2165 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 14:22:45 +0100 Subject: [PATCH 01/20] add enum --- src/torchmetrics/utilities/enums.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index c9672626727..fea1fe2aaac 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -81,3 +81,15 @@ class MDMCAverageMethod(EnumStr): GLOBAL = "global" SAMPLEWISE = "samplewise" + + +class ClassificationTask(EnumStr): + """Enum to represent the different tasks in classification metrics. + + >>> "binary" in list(ClassificationTask) + True + """ + + BINARY = "binary" + MULTICLASS = "multiclass" + MULTILABEL = "multilabel" From aee0f126bd9d91fdf51d97e9e0b1e34b7ef2f0ff Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 14:23:18 +0100 Subject: [PATCH 02/20] add enum --- src/torchmetrics/classification/accuracy.py | 7 ++++--- src/torchmetrics/classification/auroc.py | 7 ++++--- .../classification/average_precision.py | 7 ++++--- .../classification/calibration_error.py | 5 +++-- src/torchmetrics/classification/cohen_kappa.py | 5 +++-- src/torchmetrics/classification/confusion_matrix.py | 7 ++++--- src/torchmetrics/classification/exact_match.py | 5 +++-- src/torchmetrics/classification/f_beta.py | 13 +++++++------ src/torchmetrics/classification/hamming.py | 7 ++++--- src/torchmetrics/classification/hinge.py | 5 +++-- src/torchmetrics/classification/jaccard.py | 7 ++++--- .../classification/matthews_corrcoef.py | 7 ++++--- src/torchmetrics/classification/precision_recall.py | 13 +++++++------ .../classification/precision_recall_curve.py | 7 ++++--- .../classification/recall_at_fixed_precision.py | 7 ++++--- src/torchmetrics/classification/roc.py | 7 ++++--- src/torchmetrics/classification/specificity.py | 7 ++++--- .../classification/specificity_at_sensitivity.py | 7 ++++--- src/torchmetrics/classification/stat_scores.py | 7 ++++--- .../functional/classification/accuracy.py | 7 ++++--- src/torchmetrics/functional/classification/auroc.py | 7 ++++--- .../functional/classification/average_precision.py | 7 ++++--- .../functional/classification/calibration_error.py | 5 +++-- .../functional/classification/cohen_kappa.py | 5 +++-- .../functional/classification/confusion_matrix.py | 7 ++++--- .../functional/classification/exact_match.py | 3 ++- .../functional/classification/f_beta.py | 13 +++++++------ .../functional/classification/hamming.py | 7 ++++--- src/torchmetrics/functional/classification/hinge.py | 5 +++-- .../functional/classification/jaccard.py | 7 ++++--- .../functional/classification/matthews_corrcoef.py | 7 ++++--- .../functional/classification/precision_recall.py | 13 +++++++------ .../classification/precision_recall_curve.py | 7 ++++--- .../classification/recall_at_fixed_precision.py | 7 ++++--- src/torchmetrics/functional/classification/roc.py | 7 ++++--- .../functional/classification/specificity.py | 7 ++++--- .../classification/specificity_at_sensitivity.py | 7 ++++--- .../functional/classification/stat_scores.py | 8 ++++---- 38 files changed, 155 insertions(+), 118 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 6cb1704ab04..9547a979322 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -18,6 +18,7 @@ from torchmetrics.functional.classification.accuracy import _accuracy_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -453,13 +454,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAccuracy(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassAccuracy(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAccuracy(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 95be484f01b..6b531213f0e 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -353,12 +354,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAUROC(max_fpr, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAUROC(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAUROC(num_labels, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 5a78ce0401c..6cbb1492bb2 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -30,6 +30,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAveragePrecision(BinaryPrecisionRecallCurve): @@ -357,12 +358,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAveragePrecision(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAveragePrecision(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAveragePrecision(num_labels, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 335ee4ae916..c8849dc479d 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -29,6 +29,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryCalibrationError(Metric): @@ -268,9 +269,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryCalibrationError(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCalibrationError(num_classes, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 87045f6c04d..73950596a4a 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -23,6 +23,7 @@ _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryCohenKappa(BinaryConfusionMatrix): @@ -222,9 +223,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryCohenKappa(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCohenKappa(num_classes, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 57c0e931f89..f88e23206ea 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -35,6 +35,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix @@ -398,12 +399,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryConfusionMatrix(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassConfusionMatrix(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 4c0bd64768d..c86ff796bd4 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -32,6 +32,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class MulticlassExactMatch(Metric): @@ -291,10 +292,10 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassExactMatch(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelExactMatch(num_labels, threshold, **kwargs) raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 01c4b2f05e7..027c800ed3a 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -24,6 +24,7 @@ _multilabel_fbeta_score_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryFBetaScore(BinaryStatScores): @@ -729,13 +730,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryFBetaScore(beta, threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) raise ValueError( @@ -780,13 +781,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryF1Score(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassF1Score(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelF1Score(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index c4da43ffbf8..c08798c3ec1 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryHammingDistance(BinaryStatScores): @@ -349,13 +350,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryHammingDistance(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index c27b84090da..0f3798c892e 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -29,6 +29,7 @@ _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryHingeLoss(Metric): @@ -253,9 +254,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryHingeLoss(squared, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 637a862466c..589f75b840e 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -23,6 +23,7 @@ _multilabel_jaccard_index_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -296,12 +297,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryJaccardIndex(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassJaccardIndex(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index d049c061878..8886c02887f 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -19,6 +19,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): @@ -238,12 +239,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryMatthewsCorrCoef(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassMatthewsCorrCoef(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index daff2af6d3a..fc0f5767669 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecision(BinaryStatScores): @@ -620,13 +621,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryPrecision(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassPrecision(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecision(num_labels, threshold, average, **kwargs) raise ValueError( @@ -676,13 +677,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryRecall(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassRecall(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecall(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 386c994bff9..8e7070a16e1 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -37,6 +37,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecisionRecallCurve(Metric): @@ -467,12 +468,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryPrecisionRecallCurve(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassPrecisionRecallCurve(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecisionRecallCurve(num_labels, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 2f57eb05e8f..516761d62e3 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): @@ -323,14 +324,14 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassRecallAtFixedPrecision( num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecallAtFixedPrecision( num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 32c2aa5cbaf..5984fad6dd8 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -28,6 +28,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryROC(BinaryPrecisionRecallCurve): @@ -382,12 +383,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryROC(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassROC(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelROC(num_labels, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index f8a65ca2a21..4c249d43c7c 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.specificity import _specificity_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificity(BinaryStatScores): @@ -324,13 +325,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinarySpecificity(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassSpecificity(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificity(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/specificity_at_sensitivity.py b/src/torchmetrics/classification/specificity_at_sensitivity.py index f31935f9d0c..a9effa59d84 100644 --- a/src/torchmetrics/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/classification/specificity_at_sensitivity.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve): @@ -327,14 +328,14 @@ def __new__( # type: ignore validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + if task == ClassificationTask.BINARY: return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassSpecificityAtSensitivity( num_classes, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificityAtSensitivity( num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 0bd18ae41d8..f2604cbd80c 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -36,6 +36,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class _AbstractStatScores(Metric): @@ -499,13 +500,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryStatScores(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassStatScores(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelStatScores(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index a6e08093a5b..fa36ccc92d1 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _accuracy_reduce( @@ -397,15 +398,15 @@ def accuracy( tensor(0.6667) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_accuracy( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_accuracy( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index ddfccb392c4..875cd0846b4 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -38,6 +38,7 @@ ) from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -450,12 +451,12 @@ def auroc( >>> auroc(preds, target, task='multiclass', num_classes=3) tensor(0.7778) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 95507e6e2ee..4c8d84ab9b8 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -36,6 +36,7 @@ ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -438,14 +439,14 @@ def average_precision( >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None) tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_average_precision( preds, target, num_classes, average, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index d87f9bb5861..8c06e373772 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,6 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.enums import ClassificationTask def _binning_bucketize( @@ -348,9 +349,9 @@ def calibration_error( each argument influence and examples. """ assert norm is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index e01243bbd63..faa008bea16 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -27,6 +27,7 @@ _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -256,9 +257,9 @@ class labels. >>> cohen_kappa(preds, target, task="multiclass", num_classes=2) tensor(0.5000) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 047cff8351e..26970967866 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -19,6 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -630,12 +631,12 @@ def confusion_matrix( [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index b4ca3ce7dd1..df1dc380487 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -26,6 +26,7 @@ _multilabel_stat_scores_tensor_validation, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _exact_match_reduce( @@ -229,7 +230,7 @@ def exact_match( >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise') tensor([1., 0.]) """ - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert num_classes is not None return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args) if task == "multilalbe": diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index c1a68e13d5b..11a7d191f63 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _fbeta_reduce( @@ -693,15 +694,15 @@ def fbeta_score( tensor(0.3333) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_fbeta_score( preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_fbeta_score( preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args @@ -742,15 +743,15 @@ def f1_score( tensor(0.3333) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_f1_score( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_f1_score( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index c7725931fb8..9114e3f7ee7 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _hamming_distance_reduce( @@ -399,15 +400,15 @@ def hamming_distance( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_hamming_distance( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_hamming_distance( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index e2ff98141e2..fa270abd5f9 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -24,6 +24,7 @@ _multiclass_confusion_matrix_tensor_validation, ) from torchmetrics.utilities.data import to_onehot +from torchmetrics.utilities.enums import ClassificationTask def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -276,9 +277,9 @@ def hinge_loss( >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all") tensor([1.3743, 1.1945, 1.2359]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index f593a424fd2..2cbf03fb317 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -32,6 +32,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _jaccard_index_reduce( @@ -321,12 +322,12 @@ def jaccard_index( >>> jaccard_index(pred, target, task="multiclass", num_classes=2) tensor(0.9660) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index cc5f94b70e2..6e1a6c4b5d1 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -31,6 +31,7 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: @@ -235,12 +236,12 @@ def matthews_corrcoef( >>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2) tensor(0.5774) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index e00746341aa..ac6725c704a 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _precision_recall_reduce( @@ -652,15 +653,15 @@ def precision( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_precision( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args @@ -705,15 +706,15 @@ def recall( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_recall( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 9291f31d26b..8c58d185371 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -22,6 +22,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask def _binary_clf_curve( @@ -815,12 +816,12 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index 8c0e9f38578..89f60a54481 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -34,6 +34,7 @@ _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _recall_at_precision( @@ -384,14 +385,14 @@ def recall_at_fixed_precision( :func:`binary_recall_at_fixed_precision`, :func:`multiclass_recall_at_fixed_precision` and :func:`multilabel_recall_at_fixed_precision` for the specific details of each argument influence and examples. """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_recall_at_fixed_precision(preds, target, min_precision, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_recall_at_fixed_precision( preds, target, num_classes, min_precision, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall_at_fixed_precision( preds, target, num_labels, min_precision, thresholds, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 38e46a1755d..575b0376280 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -34,6 +34,7 @@ ) from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _binary_roc_compute( @@ -483,12 +484,12 @@ def roc( tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_roc(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 1a1c351a2d3..cd167768314 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _specificity_reduce( @@ -370,15 +371,15 @@ def specificity( tensor(0.6250) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_specificity( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py index c654ea26c61..034ef4e161d 100644 --- a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py @@ -36,6 +36,7 @@ _multiclass_roc_compute, _multilabel_roc_compute, ) +from torchmetrics.utilities.enums import ClassificationTask def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor: @@ -413,16 +414,16 @@ def specicity_at_sensitivity( :func:`binary_specificity_at_sensitivity`, :func:`multiclass_specicity_at_sensitivity` and :func:`multilabel_specifity_at_sensitvity` for the specific details of each argument influence and examples. """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_specificity_at_sensitivity( # type: ignore preds, target, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_specificity_at_sensitivity( # type: ignore preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity_at_sensitivity( # type: ignore preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index e9ad0d5d5e1..fceb25bce27 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount, select_topk -from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod def _binary_stat_scores_arg_validation( @@ -1083,15 +1083,15 @@ def stat_scores( [1, 0, 3, 0, 1]]) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_stat_scores( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_stat_scores( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args From 3bec0238c02aa75cfde9c159829af3f22c7e04b0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 3 Feb 2023 22:24:34 +0900 Subject: [PATCH 03/20] gh: update templates (#1477)Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen * gh: update templates --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/ISSUE_TEMPLATE/bug_report.md | 5 ++++- .github/ISSUE_TEMPLATE/documentation.md | 5 ++--- .github/PULL_REQUEST_TEMPLATE.md | 12 +++++++++--- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 487b6e8bdc2..150f7849963 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -16,11 +16,14 @@ Steps to reproduce the behavior... -#### Code sample +
+ Code sample +
+ ### Expected behavior diff --git a/.github/ISSUE_TEMPLATE/documentation.md b/.github/ISSUE_TEMPLATE/documentation.md index c74b3408000..456d49be454 100644 --- a/.github/ISSUE_TEMPLATE/documentation.md +++ b/.github/ISSUE_TEMPLATE/documentation.md @@ -10,8 +10,7 @@ assignees: '' For typos and doc fixes, please go ahead and: -1. Create an issue. -1. Fix the typo. -1. Submit a PR. +- For a simple typo or fix, please send directly a PR (no need to create an issue) +- If you are not sure about the proper solution, please describe here your finding... Thanks! diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 0c6881d0228..318ee5483eb 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -2,18 +2,24 @@ Fixes #\ -## Before submitting +
+ Before submitting -- [ ] Was this **discussed/approved** via a Github issue? (no need for typos and docs improvements) +- [ ] Was this **discussed/agreed** via a Github issue? (no need for typos and docs improvements) - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/metrics/blob/master/.github/CONTRIBUTING.md), Pull Request section? - [ ] Did you make sure to **update the docs**? - [ ] Did you write any new **necessary tests**? -## PR review +
+ +
+ PR review Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged. +
+ ## Did you have fun? Make sure you had fun coding 🙃 From 0140d12a38737286dc969baa2d3db502be281c2a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 14:22:45 +0100 Subject: [PATCH 04/20] add enum --- src/torchmetrics/utilities/enums.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index c9672626727..fea1fe2aaac 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -81,3 +81,15 @@ class MDMCAverageMethod(EnumStr): GLOBAL = "global" SAMPLEWISE = "samplewise" + + +class ClassificationTask(EnumStr): + """Enum to represent the different tasks in classification metrics. + + >>> "binary" in list(ClassificationTask) + True + """ + + BINARY = "binary" + MULTICLASS = "multiclass" + MULTILABEL = "multilabel" From 9bd0063d435d35d288bc14b23ce5950ccbcc883a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Feb 2023 14:23:18 +0100 Subject: [PATCH 05/20] add enum --- src/torchmetrics/classification/accuracy.py | 7 ++++--- src/torchmetrics/classification/auroc.py | 7 ++++--- .../classification/average_precision.py | 7 ++++--- .../classification/calibration_error.py | 5 +++-- src/torchmetrics/classification/cohen_kappa.py | 5 +++-- src/torchmetrics/classification/confusion_matrix.py | 7 ++++--- src/torchmetrics/classification/exact_match.py | 5 +++-- src/torchmetrics/classification/f_beta.py | 13 +++++++------ src/torchmetrics/classification/hamming.py | 7 ++++--- src/torchmetrics/classification/hinge.py | 5 +++-- src/torchmetrics/classification/jaccard.py | 7 ++++--- .../classification/matthews_corrcoef.py | 7 ++++--- src/torchmetrics/classification/precision_recall.py | 13 +++++++------ .../classification/precision_recall_curve.py | 7 ++++--- .../classification/recall_at_fixed_precision.py | 7 ++++--- src/torchmetrics/classification/roc.py | 7 ++++--- src/torchmetrics/classification/specificity.py | 7 ++++--- .../classification/specificity_at_sensitivity.py | 7 ++++--- src/torchmetrics/classification/stat_scores.py | 7 ++++--- .../functional/classification/accuracy.py | 7 ++++--- src/torchmetrics/functional/classification/auroc.py | 7 ++++--- .../functional/classification/average_precision.py | 7 ++++--- .../functional/classification/calibration_error.py | 5 +++-- .../functional/classification/cohen_kappa.py | 5 +++-- .../functional/classification/confusion_matrix.py | 7 ++++--- .../functional/classification/exact_match.py | 3 ++- .../functional/classification/f_beta.py | 13 +++++++------ .../functional/classification/hamming.py | 7 ++++--- src/torchmetrics/functional/classification/hinge.py | 5 +++-- .../functional/classification/jaccard.py | 7 ++++--- .../functional/classification/matthews_corrcoef.py | 7 ++++--- .../functional/classification/precision_recall.py | 13 +++++++------ .../classification/precision_recall_curve.py | 7 ++++--- .../classification/recall_at_fixed_precision.py | 7 ++++--- src/torchmetrics/functional/classification/roc.py | 7 ++++--- .../functional/classification/specificity.py | 7 ++++--- .../classification/specificity_at_sensitivity.py | 7 ++++--- .../functional/classification/stat_scores.py | 8 ++++---- 38 files changed, 155 insertions(+), 118 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 6cb1704ab04..9547a979322 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -18,6 +18,7 @@ from torchmetrics.functional.classification.accuracy import _accuracy_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val @@ -453,13 +454,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAccuracy(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassAccuracy(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAccuracy(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 95be484f01b..6b531213f0e 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAUROC(BinaryPrecisionRecallCurve): @@ -353,12 +354,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAUROC(max_fpr, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAUROC(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAUROC(num_labels, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 5a78ce0401c..6cbb1492bb2 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -30,6 +30,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryAveragePrecision(BinaryPrecisionRecallCurve): @@ -357,12 +358,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryAveragePrecision(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassAveragePrecision(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAveragePrecision(num_labels, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 335ee4ae916..c8849dc479d 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -29,6 +29,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryCalibrationError(Metric): @@ -268,9 +269,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryCalibrationError(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCalibrationError(num_classes, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 87045f6c04d..73950596a4a 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -23,6 +23,7 @@ _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryCohenKappa(BinaryConfusionMatrix): @@ -222,9 +223,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryCohenKappa(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCohenKappa(num_classes, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 57c0e931f89..f88e23206ea 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -35,6 +35,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix @@ -398,12 +399,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryConfusionMatrix(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassConfusionMatrix(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 4c0bd64768d..c86ff796bd4 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -32,6 +32,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class MulticlassExactMatch(Metric): @@ -291,10 +292,10 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassExactMatch(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelExactMatch(num_labels, threshold, **kwargs) raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 01c4b2f05e7..027c800ed3a 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -24,6 +24,7 @@ _multilabel_fbeta_score_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryFBetaScore(BinaryStatScores): @@ -729,13 +730,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryFBetaScore(beta, threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) raise ValueError( @@ -780,13 +781,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryF1Score(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassF1Score(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelF1Score(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index c4da43ffbf8..c08798c3ec1 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.hamming import _hamming_distance_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryHammingDistance(BinaryStatScores): @@ -349,13 +350,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryHammingDistance(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassHammingDistance(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index c27b84090da..0f3798c892e 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -29,6 +29,7 @@ _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryHingeLoss(Metric): @@ -253,9 +254,9 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryHingeLoss(squared, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 637a862466c..589f75b840e 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -23,6 +23,7 @@ _multilabel_jaccard_index_arg_validation, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryJaccardIndex(BinaryConfusionMatrix): @@ -296,12 +297,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryJaccardIndex(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassJaccardIndex(num_classes, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index d049c061878..8886c02887f 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -19,6 +19,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): @@ -238,12 +239,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryMatthewsCorrCoef(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassMatthewsCorrCoef(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index daff2af6d3a..fc0f5767669 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecision(BinaryStatScores): @@ -620,13 +621,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryPrecision(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassPrecision(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecision(num_labels, threshold, average, **kwargs) raise ValueError( @@ -676,13 +677,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryRecall(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassRecall(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecall(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 386c994bff9..8e7070a16e1 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -37,6 +37,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryPrecisionRecallCurve(Metric): @@ -467,12 +468,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryPrecisionRecallCurve(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassPrecisionRecallCurve(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecisionRecallCurve(num_labels, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 2f57eb05e8f..516761d62e3 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): @@ -323,14 +324,14 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassRecallAtFixedPrecision( num_classes, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecallAtFixedPrecision( num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 32c2aa5cbaf..5984fad6dd8 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -28,6 +28,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinaryROC(BinaryPrecisionRecallCurve): @@ -382,12 +383,12 @@ def __new__( **kwargs: Any, ) -> Metric: kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryROC(**kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassROC(num_classes, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelROC(num_labels, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index f8a65ca2a21..4c249d43c7c 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -19,6 +19,7 @@ from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.specificity import _specificity_reduce from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificity(BinaryStatScores): @@ -324,13 +325,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinarySpecificity(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassSpecificity(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificity(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/classification/specificity_at_sensitivity.py b/src/torchmetrics/classification/specificity_at_sensitivity.py index f31935f9d0c..a9effa59d84 100644 --- a/src/torchmetrics/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/classification/specificity_at_sensitivity.py @@ -31,6 +31,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class BinarySpecificityAtSensitivity(BinaryPrecisionRecallCurve): @@ -327,14 +328,14 @@ def __new__( # type: ignore validate_args: bool = True, **kwargs: Any, ) -> Metric: - if task == "binary": + if task == ClassificationTask.BINARY: return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return MulticlassSpecificityAtSensitivity( num_classes, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificityAtSensitivity( num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 0bd18ae41d8..f2604cbd80c 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -36,6 +36,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.enums import ClassificationTask class _AbstractStatScores(Metric): @@ -499,13 +500,13 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == "binary": + if task == ClassificationTask.BINARY: return BinaryStatScores(threshold, **kwargs) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return MulticlassStatScores(num_classes, top_k, average, **kwargs) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelStatScores(num_labels, threshold, average, **kwargs) raise ValueError( diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index a6e08093a5b..fa36ccc92d1 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _accuracy_reduce( @@ -397,15 +398,15 @@ def accuracy( tensor(0.6667) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_accuracy( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_accuracy( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index ddfccb392c4..875cd0846b4 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -38,6 +38,7 @@ ) from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -450,12 +451,12 @@ def auroc( >>> auroc(preds, target, task='multiclass', num_classes=3) tensor(0.7778) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_auroc(preds, target, num_classes, average, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 95507e6e2ee..4c8d84ab9b8 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -36,6 +36,7 @@ ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -438,14 +439,14 @@ def average_precision( >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None) tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_average_precision( preds, target, num_classes, average, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index d87f9bb5861..8c06e373772 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,6 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.enums import ClassificationTask def _binning_bucketize( @@ -348,9 +349,9 @@ def calibration_error( each argument influence and examples. """ assert norm is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index e01243bbd63..faa008bea16 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -27,6 +27,7 @@ _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -256,9 +257,9 @@ class labels. >>> cohen_kappa(preds, target, task="multiclass", num_classes=2) tensor(0.5000) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 047cff8351e..26970967866 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -19,6 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask from torchmetrics.utilities.prints import rank_zero_warn @@ -630,12 +631,12 @@ def confusion_matrix( [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_confusion_matrix(preds, target, num_classes, normalize, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index b4ca3ce7dd1..df1dc380487 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -26,6 +26,7 @@ _multilabel_stat_scores_tensor_validation, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _exact_match_reduce( @@ -229,7 +230,7 @@ def exact_match( >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise') tensor([1., 0.]) """ - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert num_classes is not None return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args) if task == "multilalbe": diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index c1a68e13d5b..11a7d191f63 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _fbeta_reduce( @@ -693,15 +694,15 @@ def fbeta_score( tensor(0.3333) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_fbeta_score( preds, target, beta, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_fbeta_score( preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args @@ -742,15 +743,15 @@ def f1_score( tensor(0.3333) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_f1_score( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_f1_score( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index c7725931fb8..9114e3f7ee7 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _hamming_distance_reduce( @@ -399,15 +400,15 @@ def hamming_distance( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_hamming_distance( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_hamming_distance( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index e2ff98141e2..fa270abd5f9 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -24,6 +24,7 @@ _multiclass_confusion_matrix_tensor_validation, ) from torchmetrics.utilities.data import to_onehot +from torchmetrics.utilities.enums import ClassificationTask def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -276,9 +277,9 @@ def hinge_loss( >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all") tensor([1.3743, 1.1945, 1.2359]) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index f593a424fd2..2cbf03fb317 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -32,6 +32,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _jaccard_index_reduce( @@ -321,12 +322,12 @@ def jaccard_index( >>> jaccard_index(pred, target, task="multiclass", num_classes=2) tensor(0.9660) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_jaccard_index(preds, target, num_classes, average, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index cc5f94b70e2..6e1a6c4b5d1 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -31,6 +31,7 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: @@ -235,12 +236,12 @@ def matthews_corrcoef( >>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2) tensor(0.5774) """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_matthews_corrcoef(preds, target, num_classes, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index e00746341aa..ac6725c704a 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _precision_recall_reduce( @@ -652,15 +653,15 @@ def precision( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_precision(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_precision( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args @@ -705,15 +706,15 @@ def recall( tensor(0.2500) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_recall( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index 9291f31d26b..8c58d185371 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -22,6 +22,7 @@ from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.enums import ClassificationTask def _binary_clf_curve( @@ -815,12 +816,12 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index 8c0e9f38578..89f60a54481 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -34,6 +34,7 @@ _multilabel_precision_recall_curve_tensor_validation, _multilabel_precision_recall_curve_update, ) +from torchmetrics.utilities.enums import ClassificationTask def _recall_at_precision( @@ -384,14 +385,14 @@ def recall_at_fixed_precision( :func:`binary_recall_at_fixed_precision`, :func:`multiclass_recall_at_fixed_precision` and :func:`multilabel_recall_at_fixed_precision` for the specific details of each argument influence and examples. """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_recall_at_fixed_precision(preds, target, min_precision, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_recall_at_fixed_precision( preds, target, num_classes, min_precision, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_recall_at_fixed_precision( preds, target, num_labels, min_precision, thresholds, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 38e46a1755d..575b0376280 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -34,6 +34,7 @@ ) from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _binary_roc_compute( @@ -483,12 +484,12 @@ def roc( tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_roc(preds, target, thresholds, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) raise ValueError( diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 1a1c351a2d3..cd167768314 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -32,6 +32,7 @@ _multilabel_stat_scores_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.enums import ClassificationTask def _specificity_reduce( @@ -370,15 +371,15 @@ def specificity( tensor(0.6250) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_specificity( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py index c654ea26c61..034ef4e161d 100644 --- a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py @@ -36,6 +36,7 @@ _multiclass_roc_compute, _multilabel_roc_compute, ) +from torchmetrics.utilities.enums import ClassificationTask def _convert_fpr_to_specificity(fpr: Tensor) -> Tensor: @@ -413,16 +414,16 @@ def specicity_at_sensitivity( :func:`binary_specificity_at_sensitivity`, :func:`multiclass_specicity_at_sensitivity` and :func:`multilabel_specifity_at_sensitvity` for the specific details of each argument influence and examples. """ - if task == "binary": + if task == ClassificationTask.BINARY: return binary_specificity_at_sensitivity( # type: ignore preds, target, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) return multiclass_specificity_at_sensitivity( # type: ignore preds, target, num_classes, min_sensitivity, thresholds, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_specificity_at_sensitivity( # type: ignore preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index e9ad0d5d5e1..fceb25bce27 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -19,7 +19,7 @@ from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount, select_topk -from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.enums import AverageMethod, ClassificationTask, DataType, MDMCAverageMethod def _binary_stat_scores_arg_validation( @@ -1083,15 +1083,15 @@ def stat_scores( [1, 0, 3, 0, 1]]) """ assert multidim_average is not None - if task == "binary": + if task == ClassificationTask.BINARY: return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) - if task == "multiclass": + if task == ClassificationTask.MULTICLASS: assert isinstance(num_classes, int) assert isinstance(top_k, int) return multiclass_stat_scores( preds, target, num_classes, average, top_k, multidim_average, ignore_index, validate_args ) - if task == "multilabel": + if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_stat_scores( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args From 5a4caef9d59f788278de482c45e7109a27b19b51 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 3 Feb 2023 14:35:12 +0100 Subject: [PATCH 06/20] StrEnum --- src/torchmetrics/utilities/enums.py | 40 ++++------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index fea1fe2aaac..b56e2a8e25d 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -11,41 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum -from typing import Optional, Union +from lightning_utilities.core.enums import StrEnum -class EnumStr(str, Enum): - """Type of any enumerator with allowed comparison to string invariant to cases. - Example: - >>> class MyEnum(EnumStr): - ... ABC = 'abc' - >>> MyEnum.from_str('Abc') - - >>> {MyEnum.ABC: 123} - {: 123} - """ - - @classmethod - def from_str(cls, value: str) -> Optional["EnumStr"]: - statuses = [status for status in dir(cls) if not status.startswith("_")] - for st in statuses: - if st.lower() == value.lower(): - return getattr(cls, st) - return None - - def __eq__(self, other: Union[str, "EnumStr", None]) -> bool: # type: ignore - other = other.value if isinstance(other, Enum) else str(other) - return self.value.lower() == other.lower() - - def __hash__(self) -> int: - # re-enable hashtable so it can be used as a dict key or in a set - # example: set(EnumStr) - return hash(self.name) - - -class DataType(EnumStr): +class DataType(StrEnum): """Enum to represent data type. >>> "Binary" in list(DataType) @@ -58,7 +28,7 @@ class DataType(EnumStr): MULTIDIM_MULTICLASS = "multi-dim multi-class" -class AverageMethod(EnumStr): +class AverageMethod(StrEnum): """Enum to represent average method. >>> None in list(AverageMethod) @@ -76,14 +46,14 @@ class AverageMethod(EnumStr): SAMPLES = "samples" -class MDMCAverageMethod(EnumStr): +class MDMCAverageMethod(StrEnum): """Enum to represent multi-dim multi-class average method.""" GLOBAL = "global" SAMPLEWISE = "samplewise" -class ClassificationTask(EnumStr): +class ClassificationTask(StrEnum): """Enum to represent the different tasks in classification metrics. >>> "binary" in list(ClassificationTask) From 169586fd817e9fe94bd6fd3f4d45195ff45f04bf Mon Sep 17 00:00:00 2001 From: Jirka Date: Sun, 5 Feb 2023 02:14:28 +0100 Subject: [PATCH 07/20] utils 0.5.0 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a121c783161..4b5ec226720 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,4 @@ numpy>=1.17.2 torch>=1.8.1 typing-extensions; python_version < '3.9' packaging # hotfix for utils, can be dropped with lit-utils >=0.5 -lightning-utilities>=0.4.1 +lightning-utilities>=0.5.0 From 2720af1f43da9324b80f93fee94da71f93fef921 Mon Sep 17 00:00:00 2001 From: Jirka Date: Sun, 5 Feb 2023 02:30:30 +0100 Subject: [PATCH 08/20] with error --- src/torchmetrics/functional/text/infolm.py | 15 +--------- src/torchmetrics/utilities/enums.py | 33 ++++++++++++++++++---- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 1afc40c6023..38c190be0fd 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -54,6 +54,7 @@ class _IMEnum(EnumStr): """A helper Enum class for storing the information measure.""" + task = "Information measure" KL_DIVERGENCE = "kl_divergence" ALPHA_DIVERGENCE = "alpha_divergence" BETA_DIVERGENCE = "beta_divergence" @@ -64,20 +65,6 @@ class _IMEnum(EnumStr): L_INFINITY_DISTANCE = "l_infinity_distance" FISHER_RAO_DISTANCE = "fisher_rao_distance" - @classmethod - def from_str(cls, value: str) -> Optional["EnumStr"]: - """ - Raises: - ValueError: - If required information measure is not among the supported options. - """ - _allowed_im = [im.lower() for im in _IMEnum._member_names_] - - enum_key = super().from_str(value) - if enum_key is not None and enum_key in _allowed_im: - return enum_key - raise ValueError(f"Invalid information measure. Expected one of {_allowed_im}, but got {enum_key}.") - class _InformationMeasure: """A wrapper class used for the calculation the result of information measure between the discrete reference diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index b56e2a8e25d..83f977cc734 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -11,24 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional -from lightning_utilities.core.enums import StrEnum +from lightning_utilities.core.enums import StrEnum as _StrEnum -class DataType(StrEnum): +class EnumStr(_StrEnum): + task: str = "Task" + + @classmethod + def from_str(cls, value: str) -> Optional["EnumStr"]: + """ + Raises: + ValueError: + If required information measure is not among the supported options. + """ + _allowed_im = [im.lower() for im in cls._member_names_] + + enum_key = super().from_str(value) + if enum_key is not None and enum_key in _allowed_im: + return enum_key + raise ValueError(f"Invalid {cls.task}: expected one of {_allowed_im}, but got {enum_key}.") + + +class DataType(EnumStr): """Enum to represent data type. >>> "Binary" in list(DataType) True """ + task = "Data type" BINARY = "binary" MULTILABEL = "multi-label" MULTICLASS = "multi-class" MULTIDIM_MULTICLASS = "multi-dim multi-class" -class AverageMethod(StrEnum): +class AverageMethod(EnumStr): """Enum to represent average method. >>> None in list(AverageMethod) @@ -39,6 +59,7 @@ class AverageMethod(StrEnum): True """ + task = "Average method" MICRO = "micro" MACRO = "macro" WEIGHTED = "weighted" @@ -46,20 +67,22 @@ class AverageMethod(StrEnum): SAMPLES = "samples" -class MDMCAverageMethod(StrEnum): +class MDMCAverageMethod(EnumStr): """Enum to represent multi-dim multi-class average method.""" + task = "MDMC Average method" GLOBAL = "global" SAMPLEWISE = "samplewise" -class ClassificationTask(StrEnum): +class ClassificationTask(EnumStr): """Enum to represent the different tasks in classification metrics. >>> "binary" in list(ClassificationTask) True """ + task = "Classification" BINARY = "binary" MULTICLASS = "multiclass" MULTILABEL = "multilabel" From a6cd44f60a65ab131dac402da03fbdbf1e09bf0b Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 06:17:03 +0100 Subject: [PATCH 09/20] links --- docs/source/links.rst | 4 +- src/torchmetrics/classification/accuracy.py | 40 +++++++++---------- src/torchmetrics/functional/image/psnr.py | 8 ++-- src/torchmetrics/functional/regression/mae.py | 4 +- src/torchmetrics/image/psnr.py | 2 +- src/torchmetrics/regression/mae.py | 4 +- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/docs/source/links.rst b/docs/source/links.rst index f5b847fe601..ed3f7732cff 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -64,10 +64,10 @@ .. _Improved Techniques for Training GANs: https://arxiv.org/abs/1606.03498 .. _KID Score: https://github.com/toshas/torch-fidelity/blob/v0.3.0/torch_fidelity/metric_kid.py .. _Demystifying MMD GANs: https://arxiv.org/abs/1801.01401 -.. _Computes Peak Signal-to-Noise Ratio: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio +.. _Compute Peak Signal-to-Noise Ratio: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio .. _Turn a Metric into a Bootstrapped: https://en.wikipedia.org/wiki/Bootstrapping_(statistics) .. _Metric Test for Reset: https://github.com/Lightning-AI/pytorch-lightning/pull/7055 -.. _Computes Mean Absolute Error: https://en.wikipedia.org/wiki/Mean_absolute_error +.. _Compute Mean Absolute Error: https://en.wikipedia.org/wiki/Mean_absolute_error .. _Mean Absolute Percentage Error: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error .. _mean squared error: https://en.wikipedia.org/wiki/Mean_squared_error .. _Aggregate the statistics from multiple devices: https://stackoverflow.com/questions/68395368/estimate-running-correlation-on-multiple-nodes diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 6cb1704ab04..bc52756ba74 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -32,7 +32,7 @@ class BinaryAccuracy(BinaryStatScores): - r"""Computes `Accuracy`_ for binary tasks: + r"""Compute `Accuracy`_ for binary tasks. .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -41,16 +41,16 @@ class BinaryAccuracy(BinaryStatScores): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating - point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid - per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` + - ``preds`` (:class:`~torch.Tensor`): An int or float tensor of shape ``(N, ...)``. If preds is a floating + point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid + per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` As output to ``forward`` and ``compute`` the metric returns the following output: - - ``ba`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, the metric returns a scalar value. - If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar - value per sample. + - ``ba`` (:class:`~torch.Tensor`): If ``multidim_average`` is set to ``global``, metric returns a scalar value. + If ``multidim_average`` is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar + value per sample. Args: threshold: Threshold for transforming probability to binary {0,1} predictions @@ -161,25 +161,25 @@ class MulticlassAccuracy(MulticlassStatScores): As input to ``forward`` and ``update`` the metric accepts the following input: - - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor of shape ``(N, C, ..)``. - If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert - probabilities/logits into an int tensor. - - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` + - ``preds`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` or float tensor + of shape ``(N, C, ..)``. If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension + to automatically convert probabilities/logits into an int tensor. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` As output to ``forward`` and ``compute`` the metric returns the following output: - - ``mca`` (:class:`~torch.Tensor`): A tensor with the accuracy score whose returned shape depends on the - ``average`` and ``multidim_average`` arguments: + - ``mca`` (:class:`~torch.Tensor`): A tensor with the accuracy score whose returned shape depends on the + ``average`` and ``multidim_average`` arguments: - - If ``multidim_average`` is set to ``global``: + - If ``multidim_average`` is set to ``global``: - - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor - - If ``average=None/'none'``, the shape will be ``(C,)`` + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` - - If ``multidim_average`` is set to ``samplewise``: + - If ``multidim_average`` is set to ``samplewise``: - - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` - - If ``average=None/'none'``, the shape will be ``(N, C)`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` Args: num_classes: Integer specifing the number of classes diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index 77db51915d0..1fb7024f1e1 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -27,7 +27,7 @@ def _psnr_compute( base: float = 10.0, reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean", ) -> Tensor: - """Computes peak signal-to-noise ratio. + """Compute peak signal-to-noise ratio. Args: sum_squared_error: Sum of square of errors over all observations @@ -60,13 +60,13 @@ def _psnr_update( target: Tensor, dim: Optional[Union[int, Tuple[int, ...]]] = None, ) -> Tuple[Tensor, Tensor]: - """Updates and returns variables required to compute peak signal-to-noise ratio. + """Update and return variables required to compute peak signal-to-noise ratio. Args: preds: Predicted tensor target: Ground truth tensor - dim: Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is - None meaning scores will be reduced across all dimensions. + dim: Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. + Default is None meaning scores will be reduced across all dimensions. """ if dim is None: diff --git a/src/torchmetrics/functional/regression/mae.py b/src/torchmetrics/functional/regression/mae.py index f3508c02cea..4f38f8f4397 100644 --- a/src/torchmetrics/functional/regression/mae.py +++ b/src/torchmetrics/functional/regression/mae.py @@ -37,7 +37,7 @@ def _mean_absolute_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, def _mean_absolute_error_compute(sum_abs_error: Tensor, n_obs: int) -> Tensor: - """Computes Mean Absolute Error. + """Compute Mean Absolute Error. Args: sum_abs_error: Sum of absolute value of errors over all observations @@ -55,7 +55,7 @@ def _mean_absolute_error_compute(sum_abs_error: Tensor, n_obs: int) -> Tensor: def mean_absolute_error(preds: Tensor, target: Tensor) -> Tensor: - """Computes mean absolute error. + """Compute mean absolute error. Args: preds: estimated labels diff --git a/src/torchmetrics/image/psnr.py b/src/torchmetrics/image/psnr.py index ca2e2ad6f5b..8b4eb0c169d 100644 --- a/src/torchmetrics/image/psnr.py +++ b/src/torchmetrics/image/psnr.py @@ -23,7 +23,7 @@ class PeakSignalNoiseRatio(Metric): - r"""Computes `Computes Peak Signal-to-Noise Ratio`_ (PSNR): + r"""`Compute Peak Signal-to-Noise Ratio`_ (PSNR): .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right) diff --git a/src/torchmetrics/regression/mae.py b/src/torchmetrics/regression/mae.py index 33f1b5a1e67..79c5b243f54 100644 --- a/src/torchmetrics/regression/mae.py +++ b/src/torchmetrics/regression/mae.py @@ -20,7 +20,7 @@ class MeanAbsoluteError(Metric): - r"""`Computes Mean Absolute Error`_ (MAE): + r"""`Compute Mean Absolute Error`_ (MAE): .. math:: \text{MAE} = \frac{1}{N}\sum_i^N | y_i - \hat{y_i} | @@ -70,5 +70,5 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.total += n_obs def compute(self) -> Tensor: - """Computes mean absolute error over state.""" + """Compute mean absolute error over state.""" return _mean_absolute_error_compute(self.sum_abs_error, self.total) From fa0fa1653d6d34da25dac54c0aeb05cf74df5ced Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 13:03:56 +0100 Subject: [PATCH 10/20] property --- .../functional/regression/kendall.py | 36 +++++-------------- src/torchmetrics/functional/text/infolm.py | 5 ++- src/torchmetrics/utilities/enums.py | 28 +++++++++++---- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 3db0cf2b565..28a8d145f10 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -26,46 +26,26 @@ class _MetricVariant(EnumStr): """Enumerate for metric variants.""" + @property + def task(self) -> str: + return "variant" + A = "a" B = "b" C = "c" - @classmethod - def from_str(cls, value: Literal["a", "b", "c"]) -> "_MetricVariant": # type: ignore[override] - """ - Raises: - ValueError: - If required metric variant is not among the supported options. - """ - _allowed_variants = [im.lower() for im in _MetricVariant._member_names_] - - enum_key = super().from_str(value) - if enum_key is not None and enum_key in _allowed_variants: - return enum_key # type: ignore[return-value] # use override - raise ValueError(f"Invalid metric variant. Expected one of {_allowed_variants}, but got {enum_key}.") - class _TestAlternative(EnumStr): """Enumerate for test altenative options.""" + @property + def task(self) -> str: + return "alternative" + TWO_SIDED = "two-sided" LESS = "less" GREATER = "greater" - @classmethod - def from_str(cls, value: Literal["two-sided", "less", "greater"]) -> "_TestAlternative": # type: ignore[override] - """ - Raises: - ValueError: - If required test alternative is not among the supported options. - """ - _allowed_alternatives = [im.lower().replace("_", "-") for im in _TestAlternative._member_names_] - - enum_key = super().from_str(value.replace("-", "_")) - if enum_key is not None and enum_key in _allowed_alternatives: - return enum_key # type: ignore[return-value] # use override - raise ValueError(f"Invalid test alternative. Expected one of {_allowed_alternatives}, but got {enum_key}.") - def _sort_on_first_sequence(x: Tensor, y: Tensor) -> Tuple[Tensor, Tensor]: """Sort sequences in an ascent order according to the sequence ``x``.""" diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 38c190be0fd..f27ff97b27e 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -54,7 +54,10 @@ class _IMEnum(EnumStr): """A helper Enum class for storing the information measure.""" - task = "Information measure" + @property + def task(self) -> str: + return "Information measure" + KL_DIVERGENCE = "kl_divergence" ALPHA_DIVERGENCE = "alpha_divergence" BETA_DIVERGENCE = "beta_divergence" diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 83f977cc734..3375ebdce92 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -13,11 +13,13 @@ # limitations under the License. from typing import Optional -from lightning_utilities.core.enums import StrEnum as _StrEnum +from lightning_utilities.core.enums import StrEnum as StrEnum -class EnumStr(_StrEnum): - task: str = "Task" +class EnumStr(StrEnum): + @property + def task(self) -> str: + return "Task" @classmethod def from_str(cls, value: str) -> Optional["EnumStr"]: @@ -41,7 +43,10 @@ class DataType(EnumStr): True """ - task = "Data type" + @property + def task(self) -> str: + return "Data type" + BINARY = "binary" MULTILABEL = "multi-label" MULTICLASS = "multi-class" @@ -59,7 +64,10 @@ class AverageMethod(EnumStr): True """ - task = "Average method" + @property + def task(self) -> str: + return "Average method" + MICRO = "micro" MACRO = "macro" WEIGHTED = "weighted" @@ -70,7 +78,10 @@ class AverageMethod(EnumStr): class MDMCAverageMethod(EnumStr): """Enum to represent multi-dim multi-class average method.""" - task = "MDMC Average method" + @property + def task(self) -> str: + return "MDMC Average method" + GLOBAL = "global" SAMPLEWISE = "samplewise" @@ -82,7 +93,10 @@ class ClassificationTask(EnumStr): True """ - task = "Classification" + @property + def task(self) -> str: + return "Classification" + BINARY = "binary" MULTICLASS = "multiclass" MULTILABEL = "multilabel" From 64da325595d344263d31b05cae9b61d5a48dd09d Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 13:04:47 +0100 Subject: [PATCH 11/20] _name --- src/torchmetrics/functional/regression/kendall.py | 4 ++-- src/torchmetrics/functional/text/infolm.py | 2 +- src/torchmetrics/utilities/enums.py | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index 28a8d145f10..6a4cb3c369b 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -27,7 +27,7 @@ class _MetricVariant(EnumStr): """Enumerate for metric variants.""" @property - def task(self) -> str: + def _name(self) -> str: return "variant" A = "a" @@ -39,7 +39,7 @@ class _TestAlternative(EnumStr): """Enumerate for test altenative options.""" @property - def task(self) -> str: + def _name(self) -> str: return "alternative" TWO_SIDED = "two-sided" diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index f27ff97b27e..3acc092b92f 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -55,7 +55,7 @@ class _IMEnum(EnumStr): """A helper Enum class for storing the information measure.""" @property - def task(self) -> str: + def _name(self) -> str: return "Information measure" KL_DIVERGENCE = "kl_divergence" diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 3375ebdce92..78a4b7ac006 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -18,7 +18,7 @@ class EnumStr(StrEnum): @property - def task(self) -> str: + def _name(self) -> str: return "Task" @classmethod @@ -33,7 +33,7 @@ def from_str(cls, value: str) -> Optional["EnumStr"]: enum_key = super().from_str(value) if enum_key is not None and enum_key in _allowed_im: return enum_key - raise ValueError(f"Invalid {cls.task}: expected one of {_allowed_im}, but got {enum_key}.") + raise ValueError(f"Invalid {cls._name}: expected one of {_allowed_im}, but got {enum_key}.") class DataType(EnumStr): @@ -44,7 +44,7 @@ class DataType(EnumStr): """ @property - def task(self) -> str: + def _name(self) -> str: return "Data type" BINARY = "binary" @@ -65,7 +65,7 @@ class AverageMethod(EnumStr): """ @property - def task(self) -> str: + def _name(self) -> str: return "Average method" MICRO = "micro" @@ -79,7 +79,7 @@ class MDMCAverageMethod(EnumStr): """Enum to represent multi-dim multi-class average method.""" @property - def task(self) -> str: + def _name(self) -> str: return "MDMC Average method" GLOBAL = "global" @@ -94,7 +94,7 @@ class ClassificationTask(EnumStr): """ @property - def task(self) -> str: + def _name(self) -> str: return "Classification" BINARY = "binary" From 0ca463d2d10b8f1206a4735572662fb26dceab7f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Feb 2023 12:08:08 +0000 Subject: [PATCH 12/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/utilities/enums.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 78a4b7ac006..c8294cf6703 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -23,10 +23,9 @@ def _name(self) -> str: @classmethod def from_str(cls, value: str) -> Optional["EnumStr"]: - """ - Raises: - ValueError: - If required information measure is not among the supported options. + """Raises: + ValueError: + If required information measure is not among the supported options. """ _allowed_im = [im.lower() for im in cls._member_names_] From 53554830061f2155fe16271640da58a3263a9921 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 13:23:08 +0100 Subject: [PATCH 13/20] chlog --- CHANGELOG.md | 6 ++++++ src/torchmetrics/utilities/enums.py | 8 +++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7cf7115e8b8..6815c621bd8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `classes` to output from `MAP` metric ([#1419](https://github.com/Lightning-AI/metrics/pull/1419)) +- Add `ClassificationTask` Enum and use in metrics ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) + + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) @@ -28,6 +31,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Raise exception for invalid kwargs in Metric base class ([#1427](https://github.com/Lightning-AI/metrics/pull/1427)) +- Extend `EnumStr` raising `ValueError` for invalid value ([#1479](https://github.com/Lightning-AI/metrics/pull/1479)) + + ### Deprecated - diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index c8294cf6703..9acaaecd35d 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -23,9 +23,11 @@ def _name(self) -> str: @classmethod def from_str(cls, value: str) -> Optional["EnumStr"]: - """Raises: - ValueError: - If required information measure is not among the supported options. + """Load from string. + + Raises: + ValueError: + If required information measure is not among the supported options. """ _allowed_im = [im.lower() for im in cls._member_names_] From fbcc0e52a1736fde069a085ad187ea616cc65dd3 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 13:48:54 +0100 Subject: [PATCH 14/20] docstring --- .../functional/regression/kendall.py | 8 ++-- src/torchmetrics/functional/text/infolm.py | 4 +- src/torchmetrics/utilities/enums.py | 39 ++++++++++++------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/src/torchmetrics/functional/regression/kendall.py b/src/torchmetrics/functional/regression/kendall.py index f05d9e695f2..f96fecca4dd 100644 --- a/src/torchmetrics/functional/regression/kendall.py +++ b/src/torchmetrics/functional/regression/kendall.py @@ -26,8 +26,8 @@ class _MetricVariant(EnumStr): """Enumerate for metric variants.""" - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "variant" A = "a" @@ -38,8 +38,8 @@ def _name(self) -> str: class _TestAlternative(EnumStr): """Enumerate for test altenative options.""" - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "alternative" TWO_SIDED = "two-sided" diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 7c31a167508..5b9088dd4f3 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -54,8 +54,8 @@ class _IMEnum(EnumStr): """A helper Enum class for storing the information measure.""" - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "Information measure" KL_DIVERGENCE = "kl_divergence" diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 9acaaecd35d..9067239f9bd 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -17,24 +17,33 @@ class EnumStr(StrEnum): - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "Task" @classmethod - def from_str(cls, value: str) -> Optional["EnumStr"]: + def from_str(cls, value: str) -> "EnumStr": """Load from string. Raises: ValueError: If required information measure is not among the supported options. - """ - _allowed_im = [im.lower() for im in cls._member_names_] + >>> class MyEnum(EnumStr): + ... a = "aaa" + ... b = "bbb" + >>> MyEnum.from_str("a") + + >>> MyEnum.from_str("c") + Traceback (most recent call last): + ... + ValueError: Invalid Task: expected one of ['a', 'b'], but got c. + """ enum_key = super().from_str(value) - if enum_key is not None and enum_key in _allowed_im: + if enum_key is not None: return enum_key - raise ValueError(f"Invalid {cls._name}: expected one of {_allowed_im}, but got {enum_key}.") + _allowed_im = [m.lower() for m in cls._member_names_] + raise ValueError(f"Invalid {cls._name()}: expected one of {_allowed_im}, but got {value}.") class DataType(EnumStr): @@ -44,8 +53,8 @@ class DataType(EnumStr): True """ - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "Data type" BINARY = "binary" @@ -65,8 +74,8 @@ class AverageMethod(EnumStr): True """ - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "Average method" MICRO = "micro" @@ -79,8 +88,8 @@ def _name(self) -> str: class MDMCAverageMethod(EnumStr): """Enum to represent multi-dim multi-class average method.""" - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "MDMC Average method" GLOBAL = "global" @@ -94,8 +103,8 @@ class ClassificationTask(EnumStr): True """ - @property - def _name(self) -> str: + @staticmethod + def _name() -> str: return "Classification" BINARY = "binary" From 01b693c6ff12a471b7ce2a7fb6e17d90e7658224 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 6 Feb 2023 14:27:51 +0100 Subject: [PATCH 15/20] remove valueerror + add from_str eval --- src/torchmetrics/classification/accuracy.py | 4 +-- src/torchmetrics/classification/auroc.py | 4 +-- .../classification/average_precision.py | 4 +-- .../classification/calibration_error.py | 10 +++---- .../classification/cohen_kappa.py | 10 +++---- .../classification/confusion_matrix.py | 4 +-- .../classification/exact_match.py | 8 ++--- src/torchmetrics/classification/f_beta.py | 4 +-- src/torchmetrics/classification/hamming.py | 5 +--- src/torchmetrics/classification/hinge.py | 10 +++---- src/torchmetrics/classification/jaccard.py | 4 +-- .../classification/matthews_corrcoef.py | 4 +-- .../classification/precision_recall.py | 8 ++--- .../classification/precision_recall_curve.py | 4 +-- .../recall_at_fixed_precision.py | 4 +-- src/torchmetrics/classification/roc.py | 4 +-- .../classification/specificity.py | 4 +-- .../specificity_at_sensitivity.py | 4 +-- .../classification/stat_scores.py | 4 +-- .../functional/classification/accuracy.py | 4 +-- .../functional/classification/auroc.py | 4 +-- .../classification/average_precision.py | 4 +-- .../classification/calibration_error.py | 7 +++-- .../functional/classification/cohen_kappa.py | 8 ++--- .../classification/confusion_matrix.py | 4 +-- .../functional/classification/exact_match.py | 8 ++--- .../functional/classification/f_beta.py | 8 ++--- .../functional/classification/hamming.py | 4 +-- .../functional/classification/hinge.py | 8 ++--- .../functional/classification/jaccard.py | 4 +-- .../classification/matthews_corrcoef.py | 4 +-- .../classification/precision_recall.py | 4 +-- .../classification/precision_recall_curve.py | 4 +-- .../recall_at_fixed_precision.py | 4 +-- .../functional/classification/roc.py | 4 +-- .../functional/classification/specificity.py | 4 +-- .../specificity_at_sensitivity.py | 4 +-- .../functional/classification/stat_scores.py | 4 +-- src/torchmetrics/utilities/enums.py | 30 +++++++++++++++++++ 39 files changed, 94 insertions(+), 134 deletions(-) diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index a121bf071ca..53900438bb5 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -445,6 +445,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) @@ -457,6 +458,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAccuracy(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index f203714db22..0405f21c97a 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -353,6 +353,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryAUROC(max_fpr, **kwargs) @@ -362,6 +363,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAUROC(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index ec8c6159d18..1f11703b8a7 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -357,6 +357,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryAveragePrecision(**kwargs) @@ -366,6 +367,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelAveragePrecision(num_labels, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index c8849dc479d..d3385ea7b3b 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -29,7 +29,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryCalibrationError(Metric): @@ -268,12 +268,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == ClassificationTask.BINARY: + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryCalibrationError(**kwargs) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCalibrationError(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index b8cd2e33ec2..1d184522d86 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -23,7 +23,7 @@ _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryCohenKappa(BinaryConfusionMatrix): @@ -222,12 +222,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args}) - if task == ClassificationTask.BINARY: + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryCohenKappa(threshold, **kwargs) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassCohenKappa(num_classes, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index d6155442675..b4277e2298a 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -398,6 +398,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryConfusionMatrix(threshold, **kwargs) @@ -407,6 +408,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index dd03009da35..bcf75dbeafc 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -32,7 +32,7 @@ ) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoBinary class MulticlassExactMatch(Metric): @@ -289,13 +289,13 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoBinary.from_str(task) kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoBinary.MULTICLASS: assert isinstance(num_classes, int) return MulticlassExactMatch(num_classes, **kwargs) - if task == ClassificationTask.MULTILABEL: + if task == ClassificationTaskNoBinary.MULTILABEL: assert isinstance(num_labels, int) return MultilabelExactMatch(num_labels, threshold, **kwargs) - raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index b381cfc077f..572ad691505 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -777,6 +777,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -790,6 +791,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelF1Score(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 16f65eee3af..251f16e5dd4 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -345,7 +345,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: - + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -359,6 +359,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 084048cab9a..e1896a64f1c 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -29,7 +29,7 @@ _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel class BinaryHingeLoss(Metric): @@ -253,12 +253,10 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTaskNoMultilabel.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) - if task == ClassificationTask.BINARY: + if task == ClassificationTaskNoMultilabel.BINARY: return BinaryHingeLoss(squared, **kwargs) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 6d0271cd9b3..64a44519475 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -296,6 +296,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryJaccardIndex(threshold, **kwargs) @@ -305,6 +306,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 85bf524a8ee..a776f17fb8a 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -238,6 +238,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryMatthewsCorrCoef(threshold, **kwargs) @@ -247,6 +248,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 2f961481d41..a3341b8dd90 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -621,6 +621,7 @@ def __new__( kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} ) + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return BinaryPrecision(threshold, **kwargs) if task == ClassificationTask.MULTICLASS: @@ -630,9 +631,6 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecision(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) class Recall: @@ -673,6 +671,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -686,6 +685,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelRecall(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index e2b3832d5c0..0bdcd73af47 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -467,6 +467,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryPrecisionRecallCurve(**kwargs) @@ -476,6 +477,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelPrecisionRecallCurve(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index eab8d2d747d..3acc2ec6ba6 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -324,6 +324,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return BinaryRecallAtFixedPrecision(min_precision, thresholds, ignore_index, validate_args, **kwargs) if task == ClassificationTask.MULTICLASS: @@ -336,6 +337,3 @@ def __new__( return MultilabelRecallAtFixedPrecision( num_labels, min_precision, thresholds, ignore_index, validate_args, **kwargs ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index ca15425bff0..fa92dd8de15 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -382,6 +382,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args}) if task == ClassificationTask.BINARY: return BinaryROC(**kwargs) @@ -391,6 +392,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelROC(num_labels, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index c3c1a10cd2d..9fbfa7df8a1 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -321,6 +321,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -334,6 +335,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelSpecificity(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/specificity_at_sensitivity.py b/src/torchmetrics/classification/specificity_at_sensitivity.py index d1c4c9effa7..a595e5ddd21 100644 --- a/src/torchmetrics/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/classification/specificity_at_sensitivity.py @@ -328,6 +328,7 @@ def __new__( # type: ignore validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return BinarySpecificityAtSensitivity(min_sensitivity, thresholds, ignore_index, validate_args, **kwargs) if task == ClassificationTask.MULTICLASS: @@ -340,6 +341,3 @@ def __new__( # type: ignore return MultilabelSpecificityAtSensitivity( num_labels, min_sensitivity, thresholds, ignore_index, validate_args, **kwargs ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index eb035b93fd6..a38321e667f 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -496,6 +496,7 @@ def __new__( validate_args: bool = True, **kwargs: Any, ) -> Metric: + task = ClassificationTask.from_str(task) assert multidim_average is not None kwargs.update( {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args} @@ -509,6 +510,3 @@ def __new__( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return MultilabelStatScores(num_labels, threshold, average, **kwargs) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index de562eb6196..ea783378fa2 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -397,6 +397,7 @@ def accuracy( >>> accuracy(preds, target, task="multiclass", num_classes=3, top_k=2) tensor(0.6667) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_accuracy(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -411,6 +412,3 @@ def accuracy( return multilabel_accuracy( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index b504f15db17..35e7a4cfebb 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -451,6 +451,7 @@ def auroc( >>> auroc(preds, target, task='multiclass', num_classes=3) tensor(0.7778) """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_auroc(preds, target, max_fpr, thresholds, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -459,6 +460,3 @@ def auroc( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_auroc(preds, target, num_labels, average, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index a2af70e1b1d..c976e43fa93 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -438,6 +438,7 @@ def average_precision( >>> average_precision(pred, target, task="multiclass", num_classes=5, average=None) tensor([1.0000, 1.0000, 0.2500, 0.2500, nan]) """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_average_precision(preds, target, thresholds, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -448,6 +449,3 @@ def average_precision( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_average_precision(preds, target, num_labels, average, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 1e0e233d0f6..6ee2f33f4a1 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -23,7 +23,7 @@ _multiclass_confusion_matrix_format, _multiclass_confusion_matrix_tensor_validation, ) -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _binning_bucketize( @@ -348,10 +348,11 @@ def calibration_error( :func:`binary_calibration_error` and :func:`multiclass_calibration_error` for the specific details of each argument influence and examples. """ + task = ClassificationTaskNoMultilabel.from_str(task) assert norm is not None - if task == ClassificationTask.BINARY: + if task == ClassificationTaskNoMultilabel.BINARY: return binary_calibration_error(preds, target, n_bins, norm, ignore_index, validate_args) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_calibration_error(preds, target, num_classes, n_bins, norm, ignore_index, validate_args) raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 39b01480727..017f487498d 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -27,7 +27,7 @@ _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -257,9 +257,9 @@ class labels. >>> cohen_kappa(preds, target, task="multiclass", num_classes=2) tensor(0.5000) """ - if task == ClassificationTask.BINARY: + task = ClassificationTaskNoMultilabel.from_str(task) + if task == ClassificationTaskNoMultilabel.BINARY: return binary_cohen_kappa(preds, target, threshold, weights, ignore_index, validate_args) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) - raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multiclass'` but got {task}") diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 07acd82285b..3774b289d19 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -631,6 +631,7 @@ def confusion_matrix( [[1, 0], [1, 0]], [[0, 1], [0, 1]]]) """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_confusion_matrix(preds, target, threshold, normalize, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -639,6 +640,3 @@ def confusion_matrix( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_confusion_matrix(preds, target, num_labels, threshold, normalize, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index 76d8d17ea4c..c66c180cd93 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -26,7 +26,7 @@ _multilabel_stat_scores_tensor_validation, ) from torchmetrics.utilities.compute import _safe_divide -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoBinary def _exact_match_reduce( @@ -230,12 +230,12 @@ def exact_match( >>> exact_match(preds, target, task="multiclass", num_classes=3, multidim_average='samplewise') tensor([1., 0.]) """ - if task == ClassificationTask.MULTICLASS: + task = ClassificationTaskNoBinary.from_str(task) + if task == ClassificationTaskNoBinary.MULTICLASS: assert num_classes is not None return multiclass_exact_match(preds, target, num_classes, multidim_average, ignore_index, validate_args) - if task == "multilabel": + if task == ClassificationTaskNoBinary.MULTILABEL: assert num_labels is not None return multilabel_exact_match( preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args ) - raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 808662e739c..e66822e4055 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -693,6 +693,7 @@ def fbeta_score( >>> fbeta_score(preds, target, task="multiclass", num_classes=3, beta=0.5) tensor(0.3333) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_fbeta_score(preds, target, beta, threshold, multidim_average, ignore_index, validate_args) @@ -707,9 +708,6 @@ def fbeta_score( return multilabel_fbeta_score( preds, target, beta, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) def f1_score( @@ -742,6 +740,7 @@ def f1_score( >>> f1_score(preds, target, task="multiclass", num_classes=3) tensor(0.3333) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_f1_score(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -756,6 +755,3 @@ def f1_score( return multilabel_f1_score( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 35fa99f0dbb..2d2f36c728c 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -399,6 +399,7 @@ def hamming_distance( >>> hamming_distance(preds, target, task="binary") tensor(0.2500) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_hamming_distance(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -413,6 +414,3 @@ def hamming_distance( return multilabel_hamming_distance( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 5c63ef0e0de..2dd69ead792 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -24,7 +24,7 @@ _multiclass_confusion_matrix_tensor_validation, ) from torchmetrics.utilities.data import to_onehot -from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.enums import ClassificationTaskNoMultilabel def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -277,9 +277,9 @@ def hinge_loss( >>> hinge_loss(preds, target, task="multiclass", num_classes=3, multiclass_mode="one-vs-all") tensor([1.3743, 1.1945, 1.2359]) """ - if task == ClassificationTask.BINARY: + task = ClassificationTaskNoMultilabel.from_str(task) + if task == ClassificationTaskNoMultilabel.BINARY: return binary_hinge_loss(preds, target, squared, ignore_index, validate_args) - if task == ClassificationTask.MULTICLASS: + if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) - raise ValueError(f"Expected argument `task` to either be `'binary'` or `'multilabel'` but got {task}") diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 39fbb10b81f..b339a49b7f9 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -322,6 +322,7 @@ def jaccard_index( >>> jaccard_index(pred, target, task="multiclass", num_classes=2) tensor(0.9660) """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_jaccard_index(preds, target, threshold, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -330,6 +331,3 @@ def jaccard_index( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 6d86d938ae3..8094a3c2ed1 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -236,6 +236,7 @@ def matthews_corrcoef( >>> matthews_corrcoef(preds, target, task="multiclass", num_classes=2) tensor(0.5774) """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_matthews_corrcoef(preds, target, threshold, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -244,6 +245,3 @@ def matthews_corrcoef( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index a9e719b2874..7f52fdfd67c 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -705,6 +705,7 @@ def recall( >>> recall(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.2500) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_recall(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -719,6 +720,3 @@ def recall( return multilabel_recall( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index ee452bb16fe..c0bb03c45c4 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -816,6 +816,7 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_precision_recall_curve(preds, target, thresholds, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -824,6 +825,3 @@ def precision_recall_curve( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_precision_recall_curve(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index 1324d99a826..83092ebd759 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -385,6 +385,7 @@ def recall_at_fixed_precision( :func:`binary_recall_at_fixed_precision`, :func:`multiclass_recall_at_fixed_precision` and :func:`multilabel_recall_at_fixed_precision` for the specific details of each argument influence and examples. """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_recall_at_fixed_precision(preds, target, min_precision, thresholds, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -397,6 +398,3 @@ def recall_at_fixed_precision( return multilabel_recall_at_fixed_precision( preds, target, num_labels, min_precision, thresholds, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index cd3535bf04f..da9751d684f 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -484,6 +484,7 @@ def roc( tensor([1.0000, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.0000, 0.1837, 0.1338, 0.1183, 0.1138])] """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_roc(preds, target, thresholds, ignore_index, validate_args) if task == ClassificationTask.MULTICLASS: @@ -492,6 +493,3 @@ def roc( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 12a02e26a67..8243e308ca9 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -370,6 +370,7 @@ def specificity( >>> specificity(preds, target, task="multiclass", average='micro', num_classes=3) tensor(0.6250) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_specificity(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -384,6 +385,3 @@ def specificity( return multilabel_specificity( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py index fa59907aaec..79aabe49c80 100644 --- a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py @@ -414,6 +414,7 @@ def specicity_at_sensitivity( :func:`binary_specificity_at_sensitivity`, :func:`multiclass_specicity_at_sensitivity` and :func:`multilabel_specifity_at_sensitvity` for the specific details of each argument influence and examples. """ + task = ClassificationTask.from_str(task) if task == ClassificationTask.BINARY: return binary_specificity_at_sensitivity( # type: ignore preds, target, min_sensitivity, thresholds, ignore_index, validate_args @@ -428,6 +429,3 @@ def specicity_at_sensitivity( return multilabel_specificity_at_sensitivity( # type: ignore preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 328c4cf9a28..e071eef531c 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -1081,6 +1081,7 @@ def stat_scores( [1, 1, 1, 1, 2], [1, 0, 3, 0, 1]]) """ + task = ClassificationTask.from_str(task) assert multidim_average is not None if task == ClassificationTask.BINARY: return binary_stat_scores(preds, target, threshold, multidim_average, ignore_index, validate_args) @@ -1095,6 +1096,3 @@ def stat_scores( return multilabel_stat_scores( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) - raise ValueError( - f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" - ) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 9067239f9bd..815fcff5aba 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -110,3 +110,33 @@ def _name() -> str: BINARY = "binary" MULTICLASS = "multiclass" MULTILABEL = "multilabel" + + +class ClassificationTaskNoBinary(EnumStr): + """Enum to represent the different tasks in classification metrics. + + >>> "binary" in list(ClassificationTask) + False + """ + + @staticmethod + def _name() -> str: + return "Classification" + + MULTILABEL = "multilabel" + MULTICLASS = "multiclass" + + +class ClassificationTaskNoMultilabel(EnumStr): + """Enum to represent the different tasks in classification metrics. + + >>> "multilabel" in list(ClassificationTask) + False + """ + + @staticmethod + def _name() -> str: + return "Classification" + + BINARY = "binary" + MULTICLASS = "multiclass" From c62dc6418feab4b03f7a74ae9a326034f049ef89 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 6 Feb 2023 17:17:56 +0100 Subject: [PATCH 16/20] doctests --- src/torchmetrics/utilities/enums.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 57651edcdf7..8d9bdd2ad8d 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -115,7 +115,7 @@ def _name() -> str: class ClassificationTaskNoBinary(EnumStr): """Enum to represent the different tasks in classification metrics. - >>> "binary" in list(ClassificationTask) + >>> "binary" in list(ClassificationTaskNoBinary) False """ @@ -130,7 +130,7 @@ def _name() -> str: class ClassificationTaskNoMultilabel(EnumStr): """Enum to represent the different tasks in classification metrics. - >>> "multilabel" in list(ClassificationTask) + >>> "multilabel" in list(ClassificationTaskNoMultilabel) False """ From 940dd6edfe77d77256d8236ec0dfc011393e0d2b Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Feb 2023 03:08:08 +0100 Subject: [PATCH 17/20] docs --- src/torchmetrics/utilities/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 8d9bdd2ad8d..016c71f9c3c 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -27,7 +27,7 @@ def from_str(cls, value: str) -> "EnumStr": Raises: ValueError: - If required information measure is not among the supported options. + If required value is not among the supported options. >>> class MyEnum(EnumStr): ... a = "aaa" From 76dc87261326c9b43227b295e27233bcac4eb5c4 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Feb 2023 03:52:01 +0100 Subject: [PATCH 18/20] - --- src/torchmetrics/utilities/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/utilities/enums.py b/src/torchmetrics/utilities/enums.py index 016c71f9c3c..f2327854165 100644 --- a/src/torchmetrics/utilities/enums.py +++ b/src/torchmetrics/utilities/enums.py @@ -39,7 +39,7 @@ def from_str(cls, value: str) -> "EnumStr": ... ValueError: Invalid Task: expected one of ['a', 'b'], but got c. """ - enum_key = super().from_str(value) + enum_key = super().from_str(value.replace("-", "_")) if enum_key is not None: return enum_key _allowed_im = [m.lower() for m in cls._member_names_] From aca9f389feabe66c5a89950d7d5d91760d098e50 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Feb 2023 08:47:50 +0100 Subject: [PATCH 19/20] mypy --- src/torchmetrics/functional/classification/accuracy.py | 1 + src/torchmetrics/functional/classification/cohen_kappa.py | 1 + src/torchmetrics/functional/classification/exact_match.py | 1 + src/torchmetrics/functional/classification/hamming.py | 1 + src/torchmetrics/functional/classification/hinge.py | 1 + src/torchmetrics/functional/classification/jaccard.py | 1 + .../functional/classification/matthews_corrcoef.py | 3 ++- .../functional/classification/precision_recall.py | 1 + src/torchmetrics/functional/classification/specificity.py | 1 + .../functional/classification/specificity_at_sensitivity.py | 1 + src/torchmetrics/functional/multimodal/clip_score.py | 4 +--- src/torchmetrics/functional/text/infolm.py | 2 +- 12 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index 1e3ee1fb07f..109ebf35c41 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -412,3 +412,4 @@ def accuracy( return multilabel_accuracy( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 86b4ac8f600..3a771858ae1 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -263,3 +263,4 @@ class labels. if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_cohen_kappa(preds, target, num_classes, weights, ignore_index, validate_args) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index afc7cef8f09..da41cb8c016 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -239,3 +239,4 @@ def exact_match( return multilabel_exact_match( preds, target, num_labels, threshold, multidim_average, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 3ea3380c6e2..e6c75379b5e 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -414,3 +414,4 @@ def hamming_distance( return multilabel_hamming_distance( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index faf470cb619..72cbdd85fb2 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -283,3 +283,4 @@ def hinge_loss( if task == ClassificationTaskNoMultilabel.MULTICLASS: assert isinstance(num_classes, int) return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, ignore_index, validate_args) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index c0caeb2758d..203dccb0b09 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -331,3 +331,4 @@ def jaccard_index( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_jaccard_index(preds, target, num_labels, threshold, average, ignore_index, validate_args) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index f71dbe62c2c..999746f11a4 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -214,7 +214,7 @@ def multilabel_matthews_corrcoef( def matthews_corrcoef( preds: Tensor, target: Tensor, - task: Literal["binary", "multiclass", "multilabel"] = None, + task: Literal["binary", "multiclass", "multilabel"], threshold: float = 0.5, num_classes: Optional[int] = None, num_labels: Optional[int] = None, @@ -245,3 +245,4 @@ def matthews_corrcoef( if task == ClassificationTask.MULTILABEL: assert isinstance(num_labels, int) return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, ignore_index, validate_args) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index ab85189ba10..c131b2169b1 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -720,3 +720,4 @@ def recall( return multilabel_recall( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index f37d25a3613..fa5f2b11567 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -385,3 +385,4 @@ def specificity( return multilabel_specificity( preds, target, num_labels, threshold, average, multidim_average, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py index 96962216f05..a97fcfa0a56 100644 --- a/src/torchmetrics/functional/classification/specificity_at_sensitivity.py +++ b/src/torchmetrics/functional/classification/specificity_at_sensitivity.py @@ -429,3 +429,4 @@ def specicity_at_sensitivity( return multilabel_specificity_at_sensitivity( # type: ignore preds, target, num_labels, min_sensitivity, thresholds, ignore_index, validate_args ) + raise ValueError(f"Not handled value: {task}") # this is for compliant of mypy diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index e5757c86619..4ee5c12c89f 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -51,9 +51,7 @@ def _clip_score_update( f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}" ) device = images[0].device - processed_input = processor( - text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True - ) # type: ignore + processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True) img_features = model.get_image_features(processed_input["pixel_values"].to(device)) img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index 25033d00100..f9c9c65b045 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -384,7 +384,7 @@ def _get_batch_distribution( for mask_idx in range(seq_len): input_ids = batch["input_ids"].clone() input_ids[:, mask_idx] = special_tokens_map["mask_token_id"] - logits_distribution = model(input_ids, batch["attention_mask"]).logits # type: ignore + logits_distribution = model(input_ids, batch["attention_mask"]).logits # [batch_size, seq_len, vocab_size] -> [batch_size, vocab_size] logits_distribution = logits_distribution[:, mask_idx, :] prob_distribution = F.softmax(logits_distribution / temperature, dim=-1) From 70314f074694b9f75ec7267c3a8c343563f5da84 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Feb 2023 08:56:41 +0100 Subject: [PATCH 20/20] mypy --- src/torchmetrics/functional/multimodal/clip_score.py | 4 +++- src/torchmetrics/functional/text/infolm.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 4ee5c12c89f..e5757c86619 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -51,7 +51,9 @@ def _clip_score_update( f"Expected the number of images and text examples to be the same but got {len(images)} and {len(text)}" ) device = images[0].device - processed_input = processor(text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True) + processed_input = processor( + text=text, images=[i.cpu() for i in images], return_tensors="pt", padding=True + ) # type: ignore img_features = model.get_image_features(processed_input["pixel_values"].to(device)) img_features = img_features / img_features.norm(p=2, dim=-1, keepdim=True) diff --git a/src/torchmetrics/functional/text/infolm.py b/src/torchmetrics/functional/text/infolm.py index f9c9c65b045..25033d00100 100644 --- a/src/torchmetrics/functional/text/infolm.py +++ b/src/torchmetrics/functional/text/infolm.py @@ -384,7 +384,7 @@ def _get_batch_distribution( for mask_idx in range(seq_len): input_ids = batch["input_ids"].clone() input_ids[:, mask_idx] = special_tokens_map["mask_token_id"] - logits_distribution = model(input_ids, batch["attention_mask"]).logits + logits_distribution = model(input_ids, batch["attention_mask"]).logits # type: ignore # [batch_size, seq_len, vocab_size] -> [batch_size, vocab_size] logits_distribution = logits_distribution[:, mask_idx, :] prob_distribution = F.softmax(logits_distribution / temperature, dim=-1)