Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix torch.bincount when running in deterministic mode and on GPU #900

Merged
merged 16 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in MAP metric related to either no ground truth or no predictions ([#884](https://github.com/PyTorchLightning/metrics/pull/884))


- Fixed `ConfusionMatrix`, `AUROC` and `AveragePrecision` on GPU when running in deterministic mode ([#900](https://github.com/PyTorchLightning/metrics/pull/900))


## [0.7.2] - 2022-02-10

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion tests/classification/test_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional import confusion_matrix
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix

seed_all(42)

Expand Down
20 changes: 19 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch import tensor

from torchmetrics.utilities import rank_zero_debug, rank_zero_info, rank_zero_warn
from torchmetrics.utilities.data import _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, get_num_classes, to_categorical, to_onehot
from torchmetrics.utilities.distributed import class_reduce, reduce


Expand Down Expand Up @@ -116,3 +116,21 @@ def test_flatten_dict():
inp = {"a": {"b": 1, "c": 2}, "d": 3}
out = _flatten_dict(inp)
assert out == {"b": 1, "c": 2, "d": 3}


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
def test_bincount():
    """Test that ``_bincount`` on GPU in deterministic mode agrees with ``torch.bincount``.

    The custom for-loop fallback is only taken when the input lives on a CUDA device and
    deterministic algorithms are enabled, so the input must be created on the GPU.
    """
    torch.use_deterministic_algorithms(True)

    # keep values strictly below ``minlength`` so both code paths yield tensors of the same length
    x = torch.randint(10, size=(100,), device="cuda")
    # uses custom implementation (deterministic mode + CUDA tensor)
    res1 = _bincount(x, minlength=10)

    torch.use_deterministic_algorithms(False)

    # uses native torch.bincount
    res2 = _bincount(x, minlength=10)

    # check for correctness
    assert torch.allclose(res1, res2)
3 changes: 2 additions & 1 deletion torchmetrics/functional/classification/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchmetrics.functional.classification.auc import _auc_compute_without_check
from torchmetrics.functional.classification.roc import roc
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.data import _bincount
from torchmetrics.utilities.enums import AverageMethod, DataType
from torchmetrics.utilities.imports import _TORCH_LOWER_1_6

Expand Down Expand Up @@ -166,7 +167,7 @@ def _auroc_compute(
if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0)
else:
support = torch.bincount(target.flatten(), minlength=num_classes)
support = _bincount(target.flatten(), minlength=num_classes)
return torch.sum(torch.stack(auc_scores) * support / support.sum())

allowed_average = (AverageMethod.NONE.value, AverageMethod.MACRO.value, AverageMethod.WEIGHTED.value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_precision_recall_curve_compute,
_precision_recall_curve_update,
)
from torchmetrics.utilities.data import _bincount


def _average_precision_update(
Expand Down Expand Up @@ -102,7 +103,7 @@ def _average_precision_compute(
if preds.ndim == target.ndim and target.ndim > 1:
weights = target.sum(dim=0).float()
else:
weights = torch.bincount(target, minlength=num_classes).float()
weights = _bincount(target, minlength=num_classes).float()
weights = weights / torch.sum(weights)
else:
weights = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _input_format_classification
from torchmetrics.utilities.data import _bincount
from torchmetrics.utilities.enums import DataType


Expand Down Expand Up @@ -45,7 +46,7 @@ def _confusion_matrix_update(
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
minlength = num_classes ** 2

bins = torch.bincount(unique_mapping, minlength=minlength)
bins = _bincount(unique_mapping, minlength=minlength)
if multilabel:
confmat = bins.reshape(num_classes, 2, 2)
else:
Expand Down
22 changes: 22 additions & 0 deletions torchmetrics/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,25 @@ def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:

def _squeeze_if_scalar(data: Any) -> Any:
return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)


def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
"""torch.bincount currently does not support deterministic mode on GPU. This implementation fallsback to a for-
loop counting occurences in that case.

Args:
x: tensor to count
minlength: minimum length to count

Returns:
Number of occurences for each unique element in x
"""
if x.is_cuda and torch.are_deterministic_algorithms_enabled():
if minlength is None:
minlength = len(torch.unique(x))
output = torch.zeros(minlength, device=x.device)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
for i in range(minlength):
output[i] = (x == i).sum()
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
return output
else:
return torch.bincount(x, minlength=minlength)