Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make num_classes optional, in case of micro averaging #2841

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
0644130
Make num_classes optional, in case of micro averaging
baskrahmer Nov 22, 2024
6d528af
Leftover from rebase
baskrahmer Nov 22, 2024
046ee9f
Merge branch 'master' into optional_num_classes
baskrahmer Nov 27, 2024
43ee11f
Adjust true-negative calculation for micro-average
baskrahmer Dec 1, 2024
d9f9ed8
Merge branch 'master' into optional_num_classes
baskrahmer Dec 6, 2024
34b78ca
Make `num_classes` optional in stat function
baskrahmer Dec 8, 2024
7b26648
ruff
baskrahmer Dec 8, 2024
577f34c
Types
baskrahmer Dec 8, 2024
f9861e2
Types
baskrahmer Dec 8, 2024
ec74350
Merge branch 'master' into optional_num_classes
baskrahmer Dec 11, 2024
ba38728
chlog
Borda Dec 17, 2024
da9984c
Merge branch 'master' into optional_num_classes
Borda Dec 21, 2024
a21457f
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
3bb36ff
Merge branch 'master' into optional_num_classes
Borda Dec 21, 2024
2048cb8
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
af613dc
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 21, 2024
96562e5
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
b5f160b
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
a5e9044
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 24, 2024
02792a4
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 25, 2024
9d832e4
Merge branch 'master' into optional_num_classes
mergify[bot] Dec 31, 2024
255db5f
Merge branch 'master' into optional_num_classes
Borda Jan 6, 2025
03f70ae
Merge branch 'master' into optional_num_classes
Borda Jan 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchmetrics/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class MulticlassStatScores(_AbstractStatScores):

def __init__(
self,
num_classes: int,
num_classes: Optional[int] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def binary_accuracy(
def multiclass_accuracy(
preds: Tensor,
target: Tensor,
num_classes: int,
num_classes: Optional[int] = None,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
top_k: int = 1,
multidim_average: Literal["global", "samplewise"] = "global",
Expand Down
18 changes: 9 additions & 9 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,15 +311,15 @@ def _multiclass_stat_scores_tensor_validation(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
" and `preds` should be (N, C, ...)."
)

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)
if num_classes is not None:
check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)


def _multiclass_stat_scores_format(
Expand Down
6 changes: 6 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,12 @@ def test_corner_cases():
res = metric(preds, target)
assert res == 1.0

metric_micro1 = MulticlassAccuracy(num_classes=None, average="micro", ignore_index=0)
metric_micro2 = MulticlassAccuracy(num_classes=3, average="micro", ignore_index=0)
res1 = metric_micro1(preds, target)
res2 = metric_micro2(preds, target)
assert res1 == res2


@pytest.mark.parametrize(
("metric", "kwargs"),
Expand Down
Loading