From fac2df4bb2b20f3421dec04264c3ddb0486224dc Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 10:01:50 +0100
Subject: [PATCH 1/9] update

---
 CHANGELOG.md                                  |  3 +++
 tests/classification/test_confusion_matrix.py | 21 +++++++++++++++++-
 .../classification/confusion_matrix.py        | 22 ++++++++++++++++++-
 3 files changed, 44 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1c9439a1c1..07edb3fd8f7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -124,6 +124,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug in MAP metric related to either no ground truth or no predictions ([#884](https://github.com/PyTorchLightning/metrics/pull/884))
 
 
+- Fixed `ConfusionMatrix` on GPU when running in deterministic mode ([]())
+
+
 ## [0.7.2] - 2022-02-10
 
 ### Fixed
diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py
index 79b05f4a2f4..8d138bdc989 100644
--- a/tests/classification/test_confusion_matrix.py
+++ b/tests/classification/test_confusion_matrix.py
@@ -31,7 +31,7 @@
 from tests.helpers import seed_all
 from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
 from torchmetrics.classification.confusion_matrix import ConfusionMatrix
-from torchmetrics.functional import confusion_matrix
+from torchmetrics.functional.classification.confusion_matrix import confusion_matrix
 
 seed_all(42)
 
@@ -186,3 +186,22 @@ def test_warning_on_nan(tmpdir):
         match=".* nan values found in confusion matrix have been replaced with zeros.",
     ):
         confusion_matrix(preds, target, num_classes=5, normalize="true")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
+def test_bincount():
+    """test that bincount works in deterministic setting on GPU."""
+    torch.use_deterministic_algorithms(True)
+
+    preds = torch.randint(3, size=(10,))
+    target = torch.randint(3, size=(10,))
+    # uses custom implementation
+    res1 = confusion_matrix(preds, target, num_classes=3)
+
+    torch.use_deterministic_algorithms(False)
+
+    # uses torch.bincount
+    res2 = confusion_matrix(preds, target, num_classes=3)
+
+    # check for correctness
+    assert torch.allclose(res1, res2)
diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py
index 60b5b0bfe05..f4be9e9a2b4 100644
--- a/torchmetrics/functional/classification/confusion_matrix.py
+++ b/torchmetrics/functional/classification/confusion_matrix.py
@@ -21,6 +21,26 @@
 from torchmetrics.utilities.enums import DataType
 
 
+def _bincount(x: Tensor, minlength: int):
+    """torch.bincount currently does not support deterministic mode on GPU. This implementation fallsback to a for-
+    loop counting occurences in that case.
+
+    Args:
+        x: tensor to count
+        minlength: minimum length to count
+
+    Returns:
+        Number of occurences for each unique element in x
+    """
+    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
+        output = torch.zeros(minlength, device=x.device)
+        for i in range(minlength):
+            output[i] = (x == i).sum()
+        return output
+    else:
+        return torch.bincount(x, minlength=minlength)
+
+
 def _confusion_matrix_update(
     preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False
 ) -> Tensor:

From 33b206974121bc5e73889eec69bcfdbf715cf756 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 22 Mar 2022 10:05:02 +0100
Subject: [PATCH 2/9] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 07edb3fd8f7..0674c822007 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -124,7 +124,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug in MAP metric related to either no ground truth or no predictions ([#884](https://github.com/PyTorchLightning/metrics/pull/884))
 
 
-- Fixed `ConfusionMatrix` on GPU when running in deterministic mode ([]())
+- Fixed `ConfusionMatrix` on GPU when running in deterministic mode ([#900](https://github.com/PyTorchLightning/metrics/pull/900))
 
 
 ## [0.7.2] - 2022-02-10
 
 ### Fixed

From d641921e3a3cc6393f30915fdebacee83e7c4bfd Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 10:22:22 +0100
Subject: [PATCH 3/9] update all torch bincounts

---
 .../functional/classification/auroc.py   |  3 ++-
 .../classification/average_precision.py  |  3 ++-
 .../classification/confusion_matrix.py   | 21 +------------------
 torchmetrics/utilities/data.py            | 20 ++++++++++++++++++
 4 files changed, 25 insertions(+), 22 deletions(-)

diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py
index 259b9cc5ce0..dd5f0f20e44 100644
--- a/torchmetrics/functional/classification/auroc.py
+++ b/torchmetrics/functional/classification/auroc.py
@@ -20,6 +20,7 @@
 from torchmetrics.functional.classification.auc import _auc_compute_without_check
 from torchmetrics.functional.classification.roc import roc
 from torchmetrics.utilities.checks import _input_format_classification
+from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import AverageMethod, DataType
 from torchmetrics.utilities.imports import _TORCH_LOWER_1_6
 
@@ -166,7 +167,7 @@ def _auroc_compute(
         if mode == DataType.MULTILABEL:
             support = torch.sum(target, dim=0)
         else:
-            support = torch.bincount(target.flatten(), minlength=num_classes)
+            support = _bincount(target.flatten(), minlength=num_classes)
         return torch.sum(torch.stack(auc_scores) * support / support.sum())
 
     allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py
index 320ce9aa33a..b121636c50d 100644
--- a/torchmetrics/functional/classification/average_precision.py
+++ b/torchmetrics/functional/classification/average_precision.py
@@ -21,6 +21,7 @@
     _precision_recall_curve_compute,
     _precision_recall_curve_update,
 )
+from torchmetrics.utilities.data import _bincount
 
 
 def _average_precision_update(
@@ -102,7 +103,7 @@ def _average_precision_compute(
         if preds.ndim == target.ndim and target.ndim > 1:
             weights = target.sum(dim=0).float()
         else:
-            weights = torch.bincount(target, minlength=num_classes).float()
+            weights = _bincount(target, minlength=num_classes).float()
         weights = weights / torch.sum(weights)
     else:
         weights = None
diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py
index f4be9e9a2b4..8f7770590e9 100644
--- a/torchmetrics/functional/classification/confusion_matrix.py
+++ b/torchmetrics/functional/classification/confusion_matrix.py
@@ -18,29 +18,10 @@
 
 from torchmetrics.utilities import rank_zero_warn
 from torchmetrics.utilities.checks import _input_format_classification
+from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import DataType
 
 
-def _bincount(x: Tensor, minlength: int):
-    """torch.bincount currently does not support deterministic mode on GPU. This implementation fallsback to a for-
-    loop counting occurences in that case.
-
-    Args:
-        x: tensor to count
-        minlength: minimum length to count
-
-    Returns:
-        Number of occurences for each unique element in x
-    """
-    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
-        output = torch.zeros(minlength, device=x.device)
-        for i in range(minlength):
-            output[i] = (x == i).sum()
-        return output
-    else:
-        return torch.bincount(x, minlength=minlength)
-
-
 def _confusion_matrix_update(
     preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False
 ) -> Tensor:
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index 8db95511c90..d1e110c6a3c 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -259,3 +259,23 @@ def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
 
 def _squeeze_if_scalar(data: Any) -> Any:
     return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)
+
+
+def _bincount(x: Tensor, minlength: int):
+    """torch.bincount currently does not support deterministic mode on GPU. This implementation fallsback to a for-
+    loop counting occurences in that case.
+
+    Args:
+        x: tensor to count
+        minlength: minimum length to count
+
+    Returns:
+        Number of occurences for each unique element in x
+    """
+    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
+        output = torch.zeros(minlength, device=x.device)
+        for i in range(minlength):
+            output[i] = (x == i).sum()
+        return output
+    else:
+        return torch.bincount(x, minlength=minlength)

From 7e56c4475501528be2a66441270d40b236990bc3 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 10:26:39 +0100
Subject: [PATCH 4/9] move test

---
 tests/classification/test_confusion_matrix.py | 19 ------------------
 tests/test_utilities.py                       | 20 ++++++++++++++++++-
 2 files changed, 19 insertions(+), 20 deletions(-)

diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py
index 8d138bdc989..08ce3e3fb4d 100644
--- a/tests/classification/test_confusion_matrix.py
+++ b/tests/classification/test_confusion_matrix.py
@@ -186,22 +186,3 @@ def test_warning_on_nan(tmpdir):
         match=".* nan values found in confusion matrix have been replaced with zeros.",
     ):
         confusion_matrix(preds, target, num_classes=5, normalize="true")
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
-def test_bincount():
-    """test that bincount works in deterministic setting on GPU."""
-    torch.use_deterministic_algorithms(True)
-
-    preds = torch.randint(3, size=(10,))
-    target = torch.randint(3, size=(10,))
-    # uses custom implementation
-    res1 = confusion_matrix(preds, target, num_classes=3)
-
-    torch.use_deterministic_algorithms(False)
-
-    # uses torch.bincount
-    res2 = confusion_matrix(preds, target, num_classes=3)
-
-    # check for correctness
-    assert torch.allclose(res1, res2)
diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index 9b989929331..0decaea4f29 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -16,7 +16,7 @@
 from torch import tensor
 
 from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
-from torchmetrics.utilities.data import _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
+from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
 from torchmetrics.utilities.distributed import class_reduce, reduce
 
 
@@ -116,3 +116,21 @@ def test_flatten_dict():
     inp = {"a": {"b": 1, "c": 2}, "d": 3}
     out = _flatten_dict(inp)
     assert out == {"b": 1, "c": 2, "d": 3}
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
+def test_bincount():
+    """test that bincount works in deterministic setting on GPU."""
+    torch.use_deterministic_algorithms(True)
+
+    x = torch.randint(100, size=(100,))
+    # uses custom implementation
+    res1 = _bincount(x, minlength=10)
+
+    torch.use_deterministic_algorithms(False)
+
+    # uses torch.bincount
+    res2 = _bincount(x, minlength=10)
+
+    # check for correctness
+    assert torch.allclose(res1, res2)

From 927980a7aaedbe72afb8c1c63c1e05ea83dfba8a Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 22 Mar 2022 10:28:12 +0100
Subject: [PATCH 5/9] Update CHANGELOG.md

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0674c822007..625b2e8b739 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -124,7 +124,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed bug in MAP metric related to either no ground truth or no predictions ([#884](https://github.com/PyTorchLightning/metrics/pull/884))
 
 
-- Fixed `ConfusionMatrix` on GPU when running in deterministic mode ([#900](https://github.com/PyTorchLightning/metrics/pull/900))
+- Fixed `ConfusionMatrix`, `AUROC` and `AveragePrecision` on GPU when running in deterministic mode ([#900](https://github.com/PyTorchLightning/metrics/pull/900))
 
 
 ## [0.7.2] - 2022-02-10
 
 ### Fixed

From 4cd9a103d7011b06ef86f44cf47a80a07325df27 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 10:33:54 +0100
Subject: [PATCH 6/9] mypy

---
 torchmetrics/utilities/data.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index d1e110c6a3c..faae46e9bc7 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -261,7 +261,7 @@ def _squeeze_if_scalar(data: Any) -> Any:
     return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)
 
 
-def _bincount(x: Tensor, minlength: int):
+def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """torch.bincount currently does not support deterministic mode on GPU. This implementation fallsback to a for-
     loop counting occurences in that case.
 
@@ -273,6 +273,8 @@ def _bincount(x: Tensor, minlength: int):
     if x.is_cuda and torch.are_deterministic_algorithms_enabled():
+        if minlength is None:
+            minlength = len(torch.unique(x))
         output = torch.zeros(minlength, device=x.device)
         for i in range(minlength):
             output[i] = (x == i).sum()
         return output

From 8e3e3285b33b5e605a38d38ce91b6446e91939a3 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 22 Mar 2022 12:43:17 +0100
Subject: [PATCH 7/9] Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 tests/test_utilities.py        | 6 +++++-
 torchmetrics/utilities/data.py | 2 +-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index 0decaea4f29..cae8aa756a8 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -131,6 +131,10 @@ def test_bincount():
 
     # uses torch.bincount
     res2 = _bincount(x, minlength=10)
-
+    
+    # explicit call to make sure, that res2 is not by accident using our manual implementation
+    res3 = torch.bincount(x. minlength=10)
+    
     # check for correctness
     assert torch.allclose(res1, res2)
+    assert torch.allclose(res1, res3)
diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py
index faae46e9bc7..cb4c9c8a609 100644
--- a/torchmetrics/utilities/data.py
+++ b/torchmetrics/utilities/data.py
@@ -275,7 +275,7 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     if x.is_cuda and torch.are_deterministic_algorithms_enabled():
         if minlength is None:
             minlength = len(torch.unique(x))
-        output = torch.zeros(minlength, device=x.device)
+        output = torch.zeros(minlength, device=x.device, dtype=torch.long)
         for i in range(minlength):
             output[i] = (x == i).sum()
         return output

From 1577e6caa46d08f8617363be2d727d8566e64c46 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 22 Mar 2022 11:44:12 +0000
Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/test_utilities.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index cae8aa756a8..33df1a5ee81 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -131,10 +131,10 @@ def test_bincount():
 
     # uses torch.bincount
     res2 = _bincount(x, minlength=10)
-    
+
     # explicit call to make sure, that res2 is not by accident using our manual implementation
     res3 = torch.bincount(x. minlength=10)
-    
+
     # check for correctness
     assert torch.allclose(res1, res2)
     assert torch.allclose(res1, res3)

From 0f200a5d40f35ac22b686eb33ffdab1695671b23 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 22 Mar 2022 12:46:04 +0100
Subject: [PATCH 9/9] fix flake

---
 tests/test_utilities.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_utilities.py b/tests/test_utilities.py
index 33df1a5ee81..67d948dc7ca 100644
--- a/tests/test_utilities.py
+++ b/tests/test_utilities.py
@@ -133,7 +133,7 @@ def test_bincount():
     res2 = _bincount(x, minlength=10)
 
     # explicit call to make sure, that res2 is not by accident using our manual implementation
-    res3 = torch.bincount(x. minlength=10)
+    res3 = torch.bincount(x, minlength=10)
 
     # check for correctness
     assert torch.allclose(res1, res2)
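
Editor's note (not part of the patch series above): the sketch below is a minimal, standalone illustration of the fallback pattern the patches introduce, using the same torch.use_deterministic_algorithms / torch.are_deterministic_algorithms_enabled API the series relies on. The function and variable names here are illustrative only and are not the torchmetrics implementation.

# Illustrative sketch only -- mirrors the idea of the `_bincount` fallback above,
# not the torchmetrics code itself.
import torch
from torch import Tensor


def bincount_compat(x: Tensor, minlength: int) -> Tensor:
    """Count occurrences of 0..minlength-1 in ``x``.

    Falls back to an elementwise compare-and-sum loop when deterministic mode is
    enabled and ``x`` lives on the GPU, because ``torch.bincount`` has no
    deterministic CUDA kernel.
    """
    if x.is_cuda and torch.are_deterministic_algorithms_enabled():
        out = torch.zeros(minlength, device=x.device, dtype=torch.long)
        for i in range(minlength):
            out[i] = (x == i).sum()  # comparison + reduction are deterministic ops
        return out
    return torch.bincount(x, minlength=minlength)


if __name__ == "__main__":
    x = torch.randint(10, size=(1000,))
    # On CPU (or with deterministic mode off) both paths agree with torch.bincount.
    assert torch.equal(bincount_compat(x, minlength=10), torch.bincount(x, minlength=10))

The trade-off, as in the patches, is one full pass over the input per bin, so the fallback costs roughly O(minlength * len(x)) instead of a single bincount kernel; it is only taken when determinism has been explicitly requested.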