Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix label checking in classification #2427

Merged — 9 commits merged on Mar 15, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))


- Fixed case where label prediction tensors in classification metrics were not validated correctly ([#2427](https://github.com/Lightning-AI/torchmetrics/pull/2427))


- Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))

## [1.3.1] - 2024-02-12
Expand Down
20 changes: 6 additions & 14 deletions src/torchmetrics/functional/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,21 +285,13 @@ def _multiclass_confusion_matrix_tensor_validation(
" and `preds` should be (N, C, ...)."
)

# Both `target` and (integer-valued) `preds` may contain at most `num_classes`
# distinct values; an `ignore_index` legitimately adds one extra value.
check_value = num_classes if ignore_index is None else num_classes + 1
# Fix: build the list of tensors to validate explicitly. The one-line form
# `((target, "target"),) + ((preds, "preds"),) if cond else ()` parses as
# `(A + B) if cond else ()`, which silently skips the `target` check whenever
# `preds` is floating point.
tensors_to_check = [(target, "target")]
if not preds.is_floating_point():
    # Floating-point preds are probabilities/logits, not class labels — skip them.
    tensors_to_check.append((preds, "preds"))
for t, name in tensors_to_check:
    num_unique_values = len(torch.unique(t))
    if num_unique_values > check_value:
        raise RuntimeError(
            f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
            f" {num_unique_values} in `{name}`."
        )


Expand Down
20 changes: 6 additions & 14 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,13 @@ def _multiclass_stat_scores_tensor_validation(
" and `preds` should be (N, C, ...)."
)

# Both `target` and (integer-valued) `preds` may contain at most `num_classes`
# distinct values; an `ignore_index` legitimately adds one extra value.
check_value = num_classes if ignore_index is None else num_classes + 1
# Fix: build the list of tensors to validate explicitly. The one-line form
# `((target, "target"),) + ((preds, "preds"),) if cond else ()` parses as
# `(A + B) if cond else ()`, which silently skips the `target` check whenever
# `preds` is floating point.
tensors_to_check = [(target, "target")]
if not preds.is_floating_point():
    # Floating-point preds are probabilities/logits, not class labels — skip them.
    tensors_to_check.append((preds, "preds"))
for t, name in tensors_to_check:
    num_unique_values = len(torch.unique(t))
    if num_unique_values > check_value:
        raise RuntimeError(
            f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
            f" {num_unique_values} in `{name}`."
        )


Expand Down
35 changes: 35 additions & 0 deletions tests/unittests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,41 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype):
)


# Template for the expected validation error; `.*` lets pytest.raises match the tail.
_TOO_MANY_VALUES_ERR = "Detected more unique values in `{}` than expected. Expected only {}.*"


@pytest.mark.parametrize(
    ("preds", "target", "ignore_index", "error_message"),
    [
        # preds contain one class too many, no ignore_index
        (
            torch.randint(NUM_CLASSES + 1, (100,)),
            torch.randint(NUM_CLASSES, (100,)),
            None,
            _TOO_MANY_VALUES_ERR.format("preds", NUM_CLASSES),
        ),
        # target contains one class too many, no ignore_index
        (
            torch.randint(NUM_CLASSES, (100,)),
            torch.randint(NUM_CLASSES + 1, (100,)),
            None,
            _TOO_MANY_VALUES_ERR.format("target", NUM_CLASSES),
        ),
        # ignore_index allows one extra value, so only +2 is out of range
        (
            torch.randint(NUM_CLASSES + 2, (100,)),
            torch.randint(NUM_CLASSES, (100,)),
            1,
            _TOO_MANY_VALUES_ERR.format("preds", NUM_CLASSES + 1),
        ),
        (
            torch.randint(NUM_CLASSES, (100,)),
            torch.randint(NUM_CLASSES + 2, (100,)),
            1,
            _TOO_MANY_VALUES_ERR.format("target", NUM_CLASSES + 1),
        ),
    ],
)
def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
    """Test that an error is raised if the number of classes in preds or target is larger than expected."""
    with pytest.raises(RuntimeError, match=error_message):
        multiclass_confusion_matrix(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)


def test_multiclass_overflow():
"""Test that multiclass computations does not overflow even on byte inputs."""
preds = torch.randint(20, (100,)).byte()
Expand Down
35 changes: 35 additions & 0 deletions tests/unittests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,41 @@ def test_multiclass_stat_scores_dtype_gpu(self, inputs, dtype):
)


# Template for the expected validation error; `.*` lets pytest.raises match the tail.
_TOO_MANY_VALUES_ERR = "Detected more unique values in `{}` than expected. Expected only {}.*"


@pytest.mark.parametrize(
    ("preds", "target", "ignore_index", "error_message"),
    [
        # preds contain one class too many, no ignore_index
        (
            torch.randint(NUM_CLASSES + 1, (100,)),
            torch.randint(NUM_CLASSES, (100,)),
            None,
            _TOO_MANY_VALUES_ERR.format("preds", NUM_CLASSES),
        ),
        # target contains one class too many, no ignore_index
        (
            torch.randint(NUM_CLASSES, (100,)),
            torch.randint(NUM_CLASSES + 1, (100,)),
            None,
            _TOO_MANY_VALUES_ERR.format("target", NUM_CLASSES),
        ),
        # ignore_index allows one extra value, so only +2 is out of range
        (
            torch.randint(NUM_CLASSES + 2, (100,)),
            torch.randint(NUM_CLASSES, (100,)),
            1,
            _TOO_MANY_VALUES_ERR.format("preds", NUM_CLASSES + 1),
        ),
        (
            torch.randint(NUM_CLASSES, (100,)),
            torch.randint(NUM_CLASSES + 2, (100,)),
            1,
            _TOO_MANY_VALUES_ERR.format("target", NUM_CLASSES + 1),
        ),
    ],
)
def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message):
    """Test that an error is raised if the number of classes in preds or target is larger than expected."""
    with pytest.raises(RuntimeError, match=error_message):
        multiclass_stat_scores(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index)


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])

Expand Down
Loading