Add ClassificationTask Enum #1479

Merged on Feb 7, 2023

(The diff below shows changes from 4 of the 32 commits.)

Commits (32)
b7a8dc7
add enum
SkafteNicki Feb 3, 2023
aee0f12
add enum
SkafteNicki Feb 3, 2023
3bec023
gh: update templates (#1477) Co-authored-by: pre-commit-ci[bot] <66853…
Borda Feb 3, 2023
20b94d5
Merge branch 'master' into classification/enum
SkafteNicki Feb 3, 2023
0140d12
add enum
SkafteNicki Feb 3, 2023
9bd0063
add enum
SkafteNicki Feb 3, 2023
5a4caef
StrEnum
Borda Feb 3, 2023
169586f
utils 0.5.0
Borda Feb 5, 2023
2720af1
with error
Borda Feb 5, 2023
a6cd44f
links
Borda Feb 6, 2023
cbc794c
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
fa0fa16
property
Borda Feb 6, 2023
64da325
_name
Borda Feb 6, 2023
f9682a0
Merge branch 'master' into classification/enum
Borda Feb 6, 2023
0ca463d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2023
7b81238
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
5355483
chlog
Borda Feb 6, 2023
e2ab95d
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
Borda Feb 6, 2023
fbcc0e5
docstring
Borda Feb 6, 2023
67aeb2f
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
49d7d7b
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
SkafteNicki Feb 6, 2023
01b693c
remove valueerror + add from_str eval
SkafteNicki Feb 6, 2023
ea8ecc0
Merge branch 'master' into classification/enum
mergify[bot] Feb 6, 2023
c62dc64
doctests
Borda Feb 6, 2023
33bd5a5
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
940dd6e
docs
Borda Feb 7, 2023
8adf66f
Merge branch 'classification/enum' of https://github.com/PyTorchLight…
Borda Feb 7, 2023
5424c1f
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
76dc872
-
Borda Feb 7, 2023
aca9f38
mypy
Borda Feb 7, 2023
70314f0
mypy
Borda Feb 7, 2023
2067b00
Merge branch 'master' into classification/enum
mergify[bot] Feb 7, 2023
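
The hunks below only show the call sites being switched over; the `ClassificationTask` enum itself is added in `torchmetrics/utilities/enums.py`, which is not among the files shown. As a rough sketch of its shape — the member names follow the string values used in the dispatch code below, while the `str` mixin and the `from_str` helper are inferred from the "StrEnum" and "add from_str eval" commit messages above, so details may differ from the merged code:

```python
from enum import Enum


class ClassificationTask(str, Enum):
    """Task types for the classification metric wrappers (sketch, not the merged source)."""

    BINARY = "binary"
    MULTICLASS = "multiclass"
    MULTILABEL = "multilabel"

    @classmethod
    def from_str(cls, value: str) -> "ClassificationTask":
        """Look up a member from a user-provided string such as ``"multiclass"``."""
        try:
            return cls(value.lower())
        except ValueError:
            raise ValueError(f"Expected `task` to be one of {[m.value for m in cls]}, got {value!r}") from None
```

Because the members mix in `str`, a comparison such as `task == ClassificationTask.BINARY` also holds when `task` is the plain string `"binary"`, which keeps the existing string-based public API working unchanged.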
5 changes: 4 additions & 1 deletion .github/ISSUE_TEMPLATE/bug_report.md
@@ -16,11 +16,14 @@ Steps to reproduce the behavior...

 <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->

-#### Code sample
+<details>
+<summary>Code sample</summary>

 <!-- Ideally attach a minimal code sample to reproduce the decried issue.
 Minimal means having the shortest code but still preserving the bug. -->

+</details>
+
 ### Expected behavior

 <!-- A clear and concise description of what you expected to happen. -->
5 changes: 2 additions & 3 deletions .github/ISSUE_TEMPLATE/documentation.md
@@ -10,8 +10,7 @@ assignees: ''

 For typos and doc fixes, please go ahead and:

-1. Create an issue.
-1. Fix the typo.
-1. Submit a PR.
+- For a simple typo or fix, please send directly a PR (no need to create an issue)
+- If you are not sure about the proper solution, please describe here your finding...

 Thanks!
12 changes: 9 additions & 3 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -2,18 +2,24 @@

 Fixes #\<issue_number>

-## Before submitting
+<details>
+<summary>Before submitting</summary>

-- [ ] Was this **discussed/approved** via a Github issue? (no need for typos and docs improvements)
+- [ ] Was this **discussed/agreed** via a Github issue? (no need for typos and docs improvements)
 - [ ] Did you read the [contributor guideline](https://github.com/Lightning-AI/metrics/blob/master/.github/CONTRIBUTING.md), Pull Request section?
 - [ ] Did you make sure to **update the docs**?
 - [ ] Did you write any new **necessary tests**?

-## PR review
+</details>
+
+<details>
+<summary>PR review</summary>

 Anyone in the community is free to review the PR once the tests have passed.
 If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

+</details>
+
 ## Did you have fun?

 Make sure you had fun coding 🙃
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/accuracy.py
@@ -18,6 +18,7 @@

 from torchmetrics.functional.classification.accuracy import _accuracy_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val

@@ -453,13 +454,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAccuracy(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassAccuracy(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAccuracy(num_labels, threshold, average, **kwargs)
         raise ValueError(
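
This dispatch pattern repeats in every wrapper below: `__new__` now compares `task` against the enum members rather than raw string literals and returns the matching task-specific metric. A minimal usage check — a sketch assuming a torchmetrics build containing this PR, not output from the test suite:

```python
import torch
from torchmetrics import Accuracy
from torchmetrics.classification import BinaryAccuracy

# The wrapper still accepts a plain string; dispatch happens via the enum comparison.
acc = Accuracy(task="binary")
assert isinstance(acc, BinaryAccuracy)

preds = torch.tensor([0.2, 0.8, 0.6])   # probabilities, thresholded at 0.5
target = torch.tensor([0, 1, 1])
print(acc(preds, target))               # tensor(1.)
```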
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/auroc.py
@@ -31,6 +31,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryAUROC(BinaryPrecisionRecallCurve):
@@ -353,12 +354,12 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAUROC(max_fpr, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassAUROC(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAUROC(num_labels, average, **kwargs)
         raise ValueError(
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/average_precision.py
@@ -30,6 +30,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryAveragePrecision(BinaryPrecisionRecallCurve):
@@ -357,12 +358,12 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"thresholds": thresholds, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryAveragePrecision(**kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassAveragePrecision(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelAveragePrecision(num_labels, average, **kwargs)
         raise ValueError(
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/calibration_error.py
@@ -29,6 +29,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryCalibrationError(Metric):
@@ -268,9 +269,9 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"n_bins": n_bins, "norm": norm, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryCalibrationError(**kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassCalibrationError(num_classes, **kwargs)
         raise ValueError(
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/cohen_kappa.py
@@ -23,6 +23,7 @@
     _multiclass_cohen_kappa_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryCohenKappa(BinaryConfusionMatrix):
@@ -222,9 +223,9 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"weights": weights, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryCohenKappa(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassCohenKappa(num_classes, **kwargs)
         raise ValueError(
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/confusion_matrix.py
@@ -35,6 +35,7 @@
     _multilabel_confusion_matrix_update,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _PLOT_OUT_TYPE, plot_confusion_matrix

@@ -398,12 +399,12 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"normalize": normalize, "ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryConfusionMatrix(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassConfusionMatrix(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelConfusionMatrix(num_labels, threshold, **kwargs)
         raise ValueError(
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/exact_match.py
@@ -32,6 +32,7 @@
 )
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.enums import ClassificationTask


 class MulticlassExactMatch(Metric):
@@ -291,10 +292,10 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassExactMatch(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelExactMatch(num_labels, threshold, **kwargs)
         raise ValueError(f"Expected argument `task` to either be `'multiclass'` or `'multilabel'` but got {task}")
13 changes: 7 additions & 6 deletions src/torchmetrics/classification/f_beta.py
@@ -24,6 +24,7 @@
     _multilabel_fbeta_score_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryFBetaScore(BinaryStatScores):
@@ -729,13 +730,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryFBetaScore(beta, threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassFBetaScore(beta, num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs)
         raise ValueError(
@@ -780,13 +781,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryF1Score(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassF1Score(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelF1Score(num_labels, threshold, average, **kwargs)
         raise ValueError(
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/hamming.py
@@ -19,6 +19,7 @@
 from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
 from torchmetrics.functional.classification.hamming import _hamming_distance_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryHammingDistance(BinaryStatScores):
@@ -349,13 +350,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryHammingDistance(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassHammingDistance(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelHammingDistance(num_labels, threshold, average, **kwargs)
         raise ValueError(
5 changes: 3 additions & 2 deletions src/torchmetrics/classification/hinge.py
@@ -29,6 +29,7 @@
     _multiclass_hinge_loss_update,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryHingeLoss(Metric):
@@ -253,9 +254,9 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryHingeLoss(squared, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs)
         raise ValueError(
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/jaccard.py
@@ -23,6 +23,7 @@
     _multilabel_jaccard_index_arg_validation,
 )
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryJaccardIndex(BinaryConfusionMatrix):
@@ -296,12 +297,12 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryJaccardIndex(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassJaccardIndex(num_classes, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs)
         raise ValueError(
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/matthews_corrcoef.py
@@ -19,6 +19,7 @@
 from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix
 from torchmetrics.functional.classification.matthews_corrcoef import _matthews_corrcoef_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryMatthewsCorrCoef(BinaryConfusionMatrix):
@@ -238,12 +239,12 @@ def __new__(
         **kwargs: Any,
     ) -> Metric:
         kwargs.update({"ignore_index": ignore_index, "validate_args": validate_args})
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryMatthewsCorrCoef(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             return MulticlassMatthewsCorrCoef(num_classes, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs)
         raise ValueError(
13 changes: 7 additions & 6 deletions src/torchmetrics/classification/precision_recall.py
@@ -19,6 +19,7 @@
 from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores
 from torchmetrics.functional.classification.precision_recall import _precision_recall_reduce
 from torchmetrics.metric import Metric
+from torchmetrics.utilities.enums import ClassificationTask


 class BinaryPrecision(BinaryStatScores):
@@ -620,13 +621,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryPrecision(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassPrecision(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelPrecision(num_labels, threshold, average, **kwargs)
         raise ValueError(
@@ -676,13 +677,13 @@ def __new__(
         kwargs.update(
             {"multidim_average": multidim_average, "ignore_index": ignore_index, "validate_args": validate_args}
         )
-        if task == "binary":
+        if task == ClassificationTask.BINARY:
             return BinaryRecall(threshold, **kwargs)
-        if task == "multiclass":
+        if task == ClassificationTask.MULTICLASS:
             assert isinstance(num_classes, int)
             assert isinstance(top_k, int)
             return MulticlassRecall(num_classes, top_k, average, **kwargs)
-        if task == "multilabel":
+        if task == ClassificationTask.MULTILABEL:
             assert isinstance(num_labels, int)
             return MultilabelRecall(num_labels, threshold, average, **kwargs)
         raise ValueError(