Fix flakiness in tests related to torch.unique with dim=None #2650

Merged · 10 commits · Sep 2, 2024
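The change is mechanical: every `torch.unique` call in the touched validation helpers now passes `dim=None` explicitly. A minimal before/after sketch of the pattern (tensor name illustrative; the PR's premise is that the explicit form sidesteps the flaky behavior seen in tests, not that the two calls are documented to differ):

import torch

target = torch.tensor([0, 1, 1, 0])

# before: default call, implicated in the flaky test failures
unique_values = torch.unique(target)

# after: dim=None passed explicitly, the form this PR standardizes on
unique_values = torch.unique(target, dim=None)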
10 changes: 5 additions & 5 deletions src/torchmetrics/functional/classification/confusion_matrix.py
@@ -94,7 +94,7 @@ def _binary_confusion_matrix_tensor_validation(
_check_same_shape(preds, target)

# Check that target only contains {0,1} values or value in ignore_index
- unique_values = torch.unique(target)
+ unique_values = torch.unique(target, dim=None)
if ignore_index is None:
check = torch.any((unique_values != 0) & (unique_values != 1))
else:
@@ -107,7 +107,7 @@ def _binary_confusion_matrix_tensor_validation(

# If preds is label tensor, also check that it only contains {0,1} values
if not preds.is_floating_point():
- unique_values = torch.unique(preds)
+ unique_values = torch.unique(preds, dim=None)
if torch.any((unique_values != 0) & (unique_values != 1)):
raise RuntimeError(
f"Detected the following values in `preds`: {unique_values} but expected only"
@@ -287,7 +287,7 @@ def _multiclass_confusion_matrix_tensor_validation(

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
- num_unique_values = len(torch.unique(t))
+ num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
@@ -454,7 +454,7 @@ def _multilabel_confusion_matrix_tensor_validation(
)

# Check that target only contains [0,1] values or value in ignore_index
- unique_values = torch.unique(target)
+ unique_values = torch.unique(target, dim=None)
if ignore_index is None:
check = torch.any((unique_values != 0) & (unique_values != 1))
else:
@@ -467,7 +467,7 @@ def _multilabel_confusion_matrix_tensor_validation(

# If preds is label tensor, also check that it only contains [0,1] values
if not preds.is_floating_point():
- unique_values = torch.unique(preds)
+ unique_values = torch.unique(preds, dim=None)
if torch.any((unique_values != 0) & (unique_values != 1)):
raise RuntimeError(
f"Detected the following values in `preds`: {unique_values} but expected only"
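The {0,1} check above recurs across the binary and multilabel validators in this file. A condensed, runnable sketch of the pattern (helper name illustrative; the `else` branch is truncated in the diff view, so the `ignore_index` arm below is a reconstruction):

import torch

def _check_binary_values(target: torch.Tensor, ignore_index=None) -> None:
    # explicit dim=None, as standardized by this PR
    unique_values = torch.unique(target, dim=None)
    if ignore_index is None:
        check = torch.any((unique_values != 0) & (unique_values != 1))
    else:
        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
    if check:
        raise RuntimeError(
            f"Detected the following values in `target`: {unique_values} but expected only"
            f" 0, 1 and ignore_index={ignore_index}."
        )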
src/torchmetrics/functional/classification/precision_recall_curve.py
@@ -148,7 +148,7 @@ def _binary_precision_recall_curve_tensor_validation(
)

# Check that target only contains {0,1} values or value in ignore_index
- unique_values = torch.unique(target)
+ unique_values = torch.unique(target, dim=None)
if ignore_index is None:
check = torch.any((unique_values != 0) & (unique_values != 1))
else:
@@ -417,7 +417,7 @@ def _multiclass_precision_recall_curve_tensor_validation(
f" but got {preds.shape} and {target.shape}"
)

- num_unique_values = len(torch.unique(target))
+ num_unique_values = len(torch.unique(target, dim=None))
check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
if check:
raise RuntimeError(
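(The raise body above is truncated by the diff view.) The multiclass variant counts distinct values rather than checking membership; a self-contained sketch of that check (function name illustrative):

import torch

def _check_multiclass_unique_count(target: torch.Tensor, num_classes: int, ignore_index=None) -> None:
    num_unique_values = len(torch.unique(target, dim=None))
    check = num_unique_values > num_classes if ignore_index is None else num_unique_values > num_classes + 1
    if check:
        raise RuntimeError(
            f"Expected at most {num_classes if ignore_index is None else num_classes + 1} unique values"
            f" in `target`, but found {num_unique_values}."
        )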
12 changes: 6 additions & 6 deletions src/torchmetrics/functional/classification/stat_scores.py
@@ -67,7 +67,7 @@ def _binary_stat_scores_tensor_validation(
_check_same_shape(preds, target)

# Check that target only contains [0,1] values or value in ignore_index
- unique_values = torch.unique(target)
+ unique_values = torch.unique(target, dim=None)
if ignore_index is None:
check = torch.any((unique_values != 0) & (unique_values != 1))
else:
@@ -80,7 +80,7 @@ def _binary_stat_scores_tensor_validation(

# If preds is label tensor, also check that it only contains [0,1] values
if not preds.is_floating_point():
- unique_values = torch.unique(preds)
+ unique_values = torch.unique(preds, dim=None)
if torch.any((unique_values != 0) & (unique_values != 1)):
raise RuntimeError(
f"Detected the following values in `preds`: {unique_values} but expected only"
@@ -314,11 +314,11 @@ def _multiclass_stat_scores_tensor_validation(

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else (): # noqa: RUF005
- num_unique_values = len(torch.unique(t))
+ num_unique_values = len(torch.unique(t, dim=None))
if num_unique_values > check_value:
raise RuntimeError(
f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
- f" {num_unique_values} in `target`."
+ f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
)
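The added `Found values: ...` clause surfaces the offending values directly in the failure message, which is what makes intermittent failures like the one in the test below diagnosable. A sketch of the loop idiom (values illustrative; note that the ternary binds looser than `+`, so the whole concatenation is skipped when `preds` is floating point):

import torch

preds = torch.tensor([0, 2, 1])   # integer labels, so preds is validated too
target = torch.tensor([0, 1, 2])
num_classes, ignore_index = 3, None

check_value = num_classes if ignore_index is None else num_classes + 1
for t, name in ((target, "target"),) + ((preds, "preds"),) if not preds.is_floating_point() else ():
    num_unique_values = len(torch.unique(t, dim=None))
    if num_unique_values > check_value:
        raise RuntimeError(
            f"Detected more unique values in `{name}` than expected. Expected only {check_value} but found"
            f" {num_unique_values} in `{name}`. Found values: {torch.unique(t, dim=None)}."
        )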


@@ -624,7 +624,7 @@ def _multilabel_stat_scores_tensor_validation(
)

# Check that target only contains [0,1] values or value in ignore_index
- unique_values = torch.unique(target)
+ unique_values = torch.unique(target, dim=None)
if ignore_index is None:
check = torch.any((unique_values != 0) & (unique_values != 1))
else:
@@ -637,7 +637,7 @@ def _multilabel_stat_scores_tensor_validation(

# If preds is label tensor, also check that it only contains [0,1] values
if not preds.is_floating_point():
- unique_values = torch.unique(preds)
+ unique_values = torch.unique(preds, dim=None)
if torch.any((unique_values != 0) & (unique_values != 1)):
raise RuntimeError(
f"Detected the following values in `preds`: {unique_values} but expected only"
2 changes: 0 additions & 2 deletions tests/unittests/classification/test_stat_scores.py
@@ -578,8 +578,6 @@ def test_multilabel_stat_scores_dtype_gpu(self, inputs, dtype):
)


- # fixme: Expected only 5 but found 7 in `target`
- @pytest.mark.flaky(reruns=5, only_rerun="RuntimeError")
def test_support_for_int():
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/1970."""
seed_all(42)
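The test body is truncated above; the PR removes the `flaky` marker and the fixme comment because the underlying nondeterminism is addressed at the source. A hypothetical sketch of an int-dtype regression test of this shape (not the actual test body):

import torch
from torchmetrics.classification import MulticlassStatScores

def test_support_for_int_sketch():
    torch.manual_seed(42)
    metric = MulticlassStatScores(num_classes=4, average="none")
    preds = torch.randint(0, 4, (100,), dtype=torch.int32)
    target = torch.randint(0, 4, (100,), dtype=torch.int32)
    assert metric(preds, target).shape == (4, 5)  # tp, fp, tn, fn, support per class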