diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a8dab4a8ba..c85b7b90023 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347)) +- Update `_safe_divide` to allow `Accuracy` to run on the GPU ([#2640](https://github.com/Lightning-AI/torchmetrics/pull/2640)) + + ## [1.4.0] - 2024-05-03 ### Added diff --git a/requirements/_docs.txt b/requirements/_docs.txt index 7977c2d1274..8a1132ce44b 100644 --- a/requirements/_docs.txt +++ b/requirements/_docs.txt @@ -9,7 +9,7 @@ sphinx-autodoc-typehints ==1.23.0 sphinx-paramlinks ==0.6.0 sphinx-togglebutton ==0.3.2 sphinx-copybutton ==0.5.2 -sphinx-gallery ==0.16.0 +sphinx-gallery ==0.17.0 lightning >=1.8.0, <2.4.0 lightning-utilities ==0.11.6 diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 68cd344877d..ee11a36136f 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -56,8 +56,8 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens """ num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() - zero_division = torch.tensor(zero_division).float().to(num.device) - return torch.where(denom != 0, num / denom, zero_division) + zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True) + return torch.where(denom != 0, num / denom, zero_division_tensor) def _adjust_weights_safe_divide( diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index db497cdb197..30d4a473a84 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -335,6 +335,77 @@ def test_multiclass_accuracy_half_gpu(self, inputs, 
dtype): dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + @pytest.mark.parametrize( + ("average", "use_deterministic_algorithms"), + [ + (None, True), # Defaults to "macro", but explicitly included for testing omission + # average=`macro` stays on GPU when `use_deterministic` is True. Otherwise syncs in `bincount` + ("macro", True), + ("micro", False), + ("micro", True), + ("weighted", True), + ], + ) + def test_multiclass_accuracy_gpu_sync_points( + self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool + ): + """Test GPU support of the metric, avoiding CPU sync points.""" + preds, target = inputs + + # Wrap the default functional to attach `sync_debug_mode` as `run_precision_test_gpu` handles moving data + # onto the GPU, so we cannot set the debug mode outside the call + def wrapped_multiclass_accuracy( + preds: torch.Tensor, + target: torch.Tensor, + num_classes: int, + ) -> torch.Tensor: + prev_sync_debug_mode = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode("error") + try: + validate_args = False # `validate_args` will require CPU sync for exceptions + # NOTE: the default `average` ("macro") relies on `_bincount`, which triggers a CPU sync + torch.use_deterministic_algorithms(mode=use_deterministic_algorithms) + return multiclass_accuracy(preds, target, num_classes, validate_args=validate_args, average=average) + finally: + torch.cuda.set_sync_debug_mode(prev_sync_debug_mode) + + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassAccuracy, + metric_functional=wrapped_multiclass_accuracy, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + @pytest.mark.parametrize( + ("average", 
"use_deterministic_algorithms"), + [ + # If you remove from this collection, please add items to `test_multiclass_accuracy_gpu_sync_points` + (None, False), + ("macro", False), + ("weighted", False), + ], + ) + def test_multiclass_accuracy_gpu_sync_points_uptodate( + self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool + ): + """Negative test for `test_multiclass_accuracy_gpu_sync_points`, to confirm completeness. + + Tests that `test_multiclass_accuracy_gpu_sync_points` is kept up to date, explicitly validating that known + failures still fail, so that if they're fixed they must be added to + `test_multiclass_accuracy_gpu_sync_points`. + + """ + with pytest.raises(RuntimeError, match="called a synchronizing CUDA operation"): + self.test_multiclass_accuracy_gpu_sync_points( + inputs=inputs, dtype=dtype, average=average, use_deterministic_algorithms=use_deterministic_algorithms + ) + _mc_k_target = torch.tensor([0, 1, 2]) _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])