Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Make num_classes optional, in case of micro averaging #2841

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0644130
Make num_classes optional, in case of micro averaging
baskrahmer Nov 22, 2024
6d528af
Leftover from rebase
baskrahmer Nov 22, 2024
046ee9f
Merge branch 'master' into optional_num_classes
baskrahmer Nov 27, 2024
43ee11f
Adjust true-negative calculation for micro-average
baskrahmer Dec 1, 2024
d9f9ed8
Merge branch 'master' into optional_num_classes
baskrahmer Dec 6, 2024
34b78ca
Make `num_classes` optional in stat function
baskrahmer Dec 8, 2024
7b26648
ruff
baskrahmer Dec 8, 2024
577f34c
Types
baskrahmer Dec 8, 2024
f9861e2
Types
baskrahmer Dec 8, 2024
ec74350
Merge branch 'master' into optional_num_classes
baskrahmer Dec 11, 2024
ba38728
chlog
Borda Dec 17, 2024
da9984c
Merge branch 'master' into optional_num_classes
Borda Dec 21, 2024
a21457f
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
3bb36ff
Merge branch 'master' into optional_num_classes
Borda Dec 21, 2024
2048cb8
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
af613dc
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
96562e5
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
b5f160b
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
a5e9044
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
02792a4
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 25, 2024
9d832e4
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 31, 2024
255db5f
Merge branch 'master' into optional_num_classes
Borda Jan 6, 2025
03f70ae
Merge branch 'master' into optional_num_classes
Borda Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Make `num_classes` optional for classification in case of micro averaging ([#2841](https://github.com/PyTorchLightning/metrics/pull/2841))


- Enabled specifying weights path for FID ([#2867](https://github.com/PyTorchLightning/metrics/pull/2867))


Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class MulticlassStatScores(_AbstractStatScores):

def __init__(
self,
num_classes: int,
num_classes: Optional[int] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
Expand All @@ -330,7 +330,7 @@ def __init__(
self.zero_division = zero_division

self._create_state(
size=1 if (average == "micro" and top_k == 1) else num_classes, multidim_average=multidim_average
size=1 if (average == "micro" and top_k == 1) else (num_classes or 1), multidim_average=multidim_average
)

def update(self, preds: Tensor, target: Tensor) -> None:
Expand All @@ -340,8 +340,9 @@ def update(self, preds: Tensor, target: Tensor) -> None:
preds, target, self.num_classes, self.multidim_average, self.ignore_index
)
preds, target = _multiclass_stat_scores_format(preds, target, self.top_k)
num_classes = self.num_classes if self.num_classes is not None else 1
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, self.num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index
preds, target, num_classes, self.top_k, self.average, self.multidim_average, self.ignore_index
)
self._update_state(tp, fp, tn, fn)

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def binary_accuracy(
def multiclass_accuracy(
preds: Tensor,
target: Tensor,
num_classes: int,
num_classes: Optional[int] = None,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
Expand Down Expand Up @@ -266,7 +266,7 @@ def multiclass_accuracy(
_multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
preds, target = _multiclass_stat_scores_format(preds, target, top_k)
tp, fp, tn, fn = _multiclass_stat_scores_update(
preds, target, num_classes, top_k, average, multidim_average, ignore_index
preds, target, num_classes or 1, top_k, average, multidim_average, ignore_index
)
return _accuracy_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k)

Expand Down
28 changes: 14 additions & 14 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def binary_stat_scores(


def _multiclass_stat_scores_arg_validation(
num_classes: int,
num_classes: Optional[int],
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
Expand All @@ -235,11 +235,11 @@ def _multiclass_stat_scores_arg_validation(
- ``zero_division`` has to be 0 or 1

"""
if not isinstance(num_classes, int) or num_classes < 2:
if num_classes is not None and (not isinstance(num_classes, int) or num_classes < 2):
raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
if not isinstance(top_k, int) and top_k < 1:
raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}")
if top_k > num_classes:
if top_k > (num_classes if num_classes is not None else 1):
raise ValueError(
f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}"
)
Expand All @@ -260,7 +260,7 @@ def _multiclass_stat_scores_arg_validation(
def _multiclass_stat_scores_tensor_validation(
preds: Tensor,
target: Tensor,
num_classes: int,
num_classes: Optional[int],
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
) -> None:
Expand All @@ -278,7 +278,7 @@ def _multiclass_stat_scores_tensor_validation(
if preds.ndim == target.ndim + 1:
if not preds.is_floating_point():
raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
if preds.shape[1] != num_classes:
if num_classes is not None and preds.shape[1] != num_classes:
raise ValueError(
"If `preds` have one dimension more than `target`, `preds.shape[1]` should be"
" equal to number of classes."
Expand Down Expand Up @@ -310,15 +310,15 @@ def _multiclass_stat_scores_tensor_validation(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
" and `preds` should be (N, C, ...)."
)

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)
if num_classes is not None:
check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)


def _multiclass_stat_scores_format(
Expand Down
6 changes: 6 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,12 @@ def test_corner_cases():
res = metric(preds, target)
assert res == 1.0

metric_micro1 = MulticlassAccuracy(num_classes=None, average="micro", ignore_index=0)
metric_micro2 = MulticlassAccuracy(num_classes=3, average="micro", ignore_index=0)
res1 = metric_micro1(preds, target)
res2 = metric_micro2(preds, target)
assert res1 == res2


@pytest.mark.parametrize(
("metric", "kwargs"),
Expand Down
Loading