Skip to content

Commit

Permalink
Update _safe_divide to allow Accuracy to run on the GPU (#2640)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jirka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 56d3495)
  • Loading branch information
ndrwrbgs authored and Borda committed Aug 2, 2024
1 parent d214dbe commit 32279a6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347))


- Update `_safe_divide` to allow `Accuracy` to run on the GPU ([#2640](https://github.com/Lightning-AI/torchmetrics/pull/2640))


## [1.4.0] - 2024-05-03

### Added
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens
"""
num = num if num.is_floating_point() else num.float()
denom = denom if denom.is_floating_point() else denom.float()
zero_division = torch.tensor(zero_division).float().to(num.device)
return torch.where(denom != 0, num / denom, zero_division)
zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True)
return torch.where(denom != 0, num / denom, zero_division_tensor)


def _adjust_weights_safe_divide(
Expand Down
71 changes: 71 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,77 @@ def test_multiclass_accuracy_half_gpu(self, inputs, dtype):
dtype=dtype,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
@pytest.mark.parametrize(
    ("average", "use_deterministic_algorithms"),
    [
        (None, True),  # Defaults to "macro", but explicitly included for testing omission
        # average=`macro` stays on GPU when `use_deterministic` is True. Otherwise syncs in `bincount`
        ("macro", True),
        ("micro", False),
        ("micro", True),
        ("weighted", True),
    ],
)
def test_multiclass_accuracy_gpu_sync_points(
    self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool
):
    """Test GPU support of the metric, avoiding CPU sync points.

    Runs the functional metric with ``torch.cuda.set_sync_debug_mode("error")`` active, so any
    operation that forces a CPU<->GPU synchronization raises a ``RuntimeError``.
    """
    preds, target = inputs

    # Wrap the default functional to attach `sync_debug_mode`, as `run_precision_test_gpu` handles
    # moving data onto the GPU, so we cannot set the debug mode outside the call.
    def wrapped_multiclass_accuracy(
        preds: torch.Tensor,
        target: torch.Tensor,
        num_classes: int,
    ) -> torch.Tensor:
        # Snapshot BOTH pieces of global state we mutate, so they can be restored afterwards.
        # (Previously the deterministic-algorithms flag was set but never restored, leaking the
        # setting into every test that runs after this one.)
        prev_sync_debug_mode = torch.cuda.get_sync_debug_mode()
        prev_deterministic = torch.are_deterministic_algorithms_enabled()
        torch.cuda.set_sync_debug_mode("error")
        try:
            torch.use_deterministic_algorithms(mode=use_deterministic_algorithms)
            # `validate_args=False`: argument validation requires a CPU sync to raise exceptions.
            validate_args = False
            return multiclass_accuracy(preds, target, num_classes, validate_args=validate_args, average=average)
        finally:
            # Restore global state even when the metric call raises (e.g. on a detected sync).
            torch.cuda.set_sync_debug_mode(prev_sync_debug_mode)
            torch.use_deterministic_algorithms(mode=prev_deterministic)

    self.run_precision_test_gpu(
        preds=preds,
        target=target,
        metric_module=MulticlassAccuracy,
        metric_functional=wrapped_multiclass_accuracy,
        metric_args={"num_classes": NUM_CLASSES},
        dtype=dtype,
    )

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
@pytest.mark.parametrize(
    ("average", "use_deterministic_algorithms"),
    [
        # If you remove from this collection, please add items to `test_multiclass_accuracy_gpu_sync_points`
        (None, False),
        ("macro", False),
        ("weighted", False),
    ],
)
def test_multiclass_accuracy_gpu_sync_points_uptodate(
    self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool
):
    """Negative counterpart of `test_multiclass_accuracy_gpu_sync_points`, confirming completeness.

    Explicitly asserts that the parametrizations known to trigger a CUDA sync still fail, so that
    once a case is fixed it must be migrated into `test_multiclass_accuracy_gpu_sync_points`.
    """
    expected_message = "called a synchronizing CUDA operation"
    with pytest.raises(RuntimeError, match=expected_message):
        self.test_multiclass_accuracy_gpu_sync_points(
            inputs=inputs,
            dtype=dtype,
            average=average,
            use_deterministic_algorithms=use_deterministic_algorithms,
        )


# NOTE(review): shared fixture tensors — presumably inputs for the top-k accuracy tests that
# follow later in the file (not visible in this chunk); three samples, three classes, with
# per-sample class probabilities in `_mc_k_preds`.
_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
Expand Down

0 comments on commit 32279a6

Please sign in to comment.