Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update _safe_divide to allow Accuracy to run on the GPU #2640

Merged
merged 20 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed BertScore calculation: pred target misalignment ([#2347](https://github.com/Lightning-AI/torchmetrics/pull/2347))


- Update `_safe_divide` to allow `Accuracy` to run on the GPU ([#2640](https://github.com/Lightning-AI/torchmetrics/pull/2640))


## [1.4.0] - 2024-05-03

### Added
Expand Down
1 change: 0 additions & 1 deletion src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens
"""
num = num if num.is_floating_point() else num.float()
denom = denom if denom.is_floating_point() else denom.float()
zero_division = torch.tensor(zero_division).float().to(num.device)
return torch.where(denom != 0, num / denom, zero_division)


Expand Down
71 changes: 71 additions & 0 deletions tests/unittests/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,77 @@ def test_multiclass_accuracy_half_gpu(self, inputs, dtype):
dtype=dtype,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
@pytest.mark.parametrize(
("average", "use_deterministic_algorithms"),
[
(None, True), # Defaults to "macro", but explicitly included for testing omission
# average=`macro` stays on GPU when `use_deterministic` is True. Otherwise syncs in `bincount`
("macro", True),
("micro", False),
("micro", True),
("weighted", True),
],
)
def test_multiclass_accuracy_gpu_sync_points(
self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool
):
"""Test GPU support of the metric, avoiding CPU sync points."""
preds, target = inputs

# Wrap the default functional to attach `sync_debug_mode` as `run_precision_test_gpu` handles moving data
# onto the GPU, so we cannot set the debug mode outside the call
def wrapped_multiclass_accuracy(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
) -> torch.Tensor:
prev_sync_debug_mode = torch.cuda.get_sync_debug_mode()
torch.cuda.set_sync_debug_mode("error")
try:
validate_args = False # `validate_args` will require CPU sync for exceptions
# average = average #'micro' # default is `macro` which uses a `_bincount` that does a CPU sync
torch.use_deterministic_algorithms(mode=use_deterministic_algorithms)
return multiclass_accuracy(preds, target, num_classes, validate_args=validate_args, average=average)
finally:
torch.cuda.set_sync_debug_mode(prev_sync_debug_mode)

self.run_precision_test_gpu(
preds=preds,
target=target,
metric_module=MulticlassAccuracy,
metric_functional=wrapped_multiclass_accuracy,
metric_args={"num_classes": NUM_CLASSES},
dtype=dtype,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
@pytest.mark.parametrize("dtype", [torch.half, torch.double])
@pytest.mark.parametrize(
("average", "use_deterministic_algorithms"),
[
# If you remove from this collection, please add items to `test_multiclass_accuracy_gpu_sync_points`
(None, False),
("macro", False),
("weighted", False),
],
)
def test_multiclass_accuracy_gpu_sync_points_uptodate(
self, inputs, dtype: torch.dtype, average: str, use_deterministic_algorithms: bool
):
"""Negative test for `test_multiclass_accuracy_gpu_sync_points`, to confirm completeness.

Tests that `test_multiclass_accuracy_gpu_sync_points` is kept up to date, explicitly validating that known
failures still fail, so that if they're fixed they must be added to
`test_multiclass_accuracy_gpu_sync_points`.

"""
with pytest.raises(RuntimeError, match="called a synchronizing CUDA operation"):
self.test_multiclass_accuracy_gpu_sync_points(
inputs=inputs, dtype=dtype, average=average, use_deterministic_algorithms=use_deterministic_algorithms
)


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
Expand Down
Loading