diff --git a/CHANGELOG.md b/CHANGELOG.md
index 583e5611730..d81fa86352c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug when `top_k>1` and `average="macro"` for classification metrics ([#2423](https://github.com/Lightning-AI/torchmetrics/pull/2423))
 
+- Fixed case where label prediction tensors in classification metrics were not validated correctly ([#2427](https://github.com/Lightning-AI/torchmetrics/pull/2427))
+
+
 - Fixed how auc scores are calculated in `PrecisionRecallCurve.plot` methods ([#2437](https://github.com/Lightning-AI/torchmetrics/pull/2437))
 
 
 ## [1.3.1] - 2024-02-12
diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index c51770ae7d6..b668c152e5b 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -285,21 +285,13 @@ def _multiclass_confusion_matrix_tensor_validation(
             " and `preds` should be (N, C, ...)."
         )
 
-    num_unique_values = len(torch.unique(target))
-    check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
-    if check:
-        raise RuntimeError(
-            "Detected more unique values in `target` than `num_classes`. Expected only "
-            f"{num_classes if ignore_index is None else num_classes + 1} but found "
-            f"{num_unique_values} in `target`."
-        )
-
-    if not preds.is_floating_point():
-        num_unique_values = len(torch.unique(preds))
-        if num_unique_values > num_classes:
+    check_value = num_classes if ignore_index is None else num_classes + 1
+    for t, name in ((target, "target"), (preds, "preds")) if not preds.is_floating_point() else ((target, "target"),):
+        num_unique_values = len(torch.unique(t))
+        if num_unique_values > check_value:
             raise RuntimeError(
-                "Detected more unique values in `preds` than `num_classes`. Expected only "
-                f"{num_classes} but found {num_unique_values} in `preds`."
+                f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
+                f" {num_unique_values} in `{name}`."
             )
 
 
diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py
index 5153554253b..aa8e0bf5016 100644
--- a/src/torchmetrics/functional/classification/stat_scores.py
+++ b/src/torchmetrics/functional/classification/stat_scores.py
@@ -304,21 +304,13 @@ def _multiclass_stat_scores_tensor_validation(
             " and `preds` should be (N, C, ...)."
         )
 
-    num_unique_values = len(torch.unique(target))
-    check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
-    if check:
-        raise RuntimeError(
-            "Detected more unique values in `target` than `num_classes`. Expected only"
-            f" {num_classes if ignore_index is None else num_classes + 1} but found"
-            f" {num_unique_values} in `target`."
-        )
-
-    if not preds.is_floating_point():
-        unique_values = torch.unique(preds)
-        if len(unique_values) > num_classes:
+    check_value = num_classes if ignore_index is None else num_classes + 1
+    for t, name in ((target, "target"), (preds, "preds")) if not preds.is_floating_point() else ((target, "target"),):
+        num_unique_values = len(torch.unique(t))
+        if num_unique_values > check_value:
             raise RuntimeError(
-                "Detected more unique values in `preds` than `num_classes`. Expected only"
-                f" {num_classes} but found {len(unique_values)} in `preds`."
+ f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found" + f" {num_unique_values} in `target`." ) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 666e9f0fc05..12f21451949 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -239,6 +239,41 @@ def test_multiclass_confusion_matrix_dtype_gpu(self, inputs, dtype): ) +@pytest.mark.parametrize( + ("preds", "target", "ignore_index", "error_message"), + [ + ( + torch.randint(NUM_CLASSES + 1, (100,)), + torch.randint(NUM_CLASSES, (100,)), + None, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", + ), + ( + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 1, (100,)), + None, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", + ), + ( + torch.randint(NUM_CLASSES + 2, (100,)), + torch.randint(NUM_CLASSES, (100,)), + 1, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", + ), + ( + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 2, (100,)), + 1, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", + ), + ], +) +def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message): + """Test that an error is raised if the number of classes in preds or target is larger than expected.""" + with pytest.raises(RuntimeError, match=error_message): + multiclass_confusion_matrix(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index) + + def test_multiclass_overflow(): """Test that multiclass computations does not overflow even on byte inputs.""" preds = torch.randint(20, (100,)).byte() diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 86c793f8c83..0b11200eb2f 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -325,6 +325,41 @@ def test_multiclass_stat_scores_dtype_gpu(self, inputs, dtype): ) +@pytest.mark.parametrize( + ("preds", "target", "ignore_index", "error_message"), + [ + ( + torch.randint(NUM_CLASSES + 1, (100,)), + torch.randint(NUM_CLASSES, (100,)), + None, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES}.*", + ), + ( + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 1, (100,)), + None, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES}.*", + ), + ( + torch.randint(NUM_CLASSES + 2, (100,)), + torch.randint(NUM_CLASSES, (100,)), + 1, + f"Detected more unique values in `preds` than expected. Expected only {NUM_CLASSES + 1}.*", + ), + ( + torch.randint(NUM_CLASSES, (100,)), + torch.randint(NUM_CLASSES + 2, (100,)), + 1, + f"Detected more unique values in `target` than expected. Expected only {NUM_CLASSES + 1}.*", + ), + ], +) +def test_raises_error_on_too_many_classes(preds, target, ignore_index, error_message): + """Test that an error is raised if the number of classes in preds or target is larger than expected.""" + with pytest.raises(RuntimeError, match=error_message): + multiclass_stat_scores(preds, target, num_classes=NUM_CLASSES, ignore_index=ignore_index) + + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])