Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ROC metric for CUDA tensors #2304

Merged
merged 9 commits into from
Jun 23, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Tuple, Callable

import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.utilities import rank_zero_warn
Expand Down Expand Up @@ -500,8 +501,7 @@ def _binary_clf_curve(
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
threshold_idxs = torch.cat([distinct_value_indices,
torch.tensor([target.size(0) - 1])])
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
Expand Down
20 changes: 17 additions & 3 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,37 @@ def test_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
pred = torch.randint(10, (500,))
target = torch.randint(10, (500,))

if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'

pred = pred.to(device)
target = target.to(device)

assert torch.allclose(
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
torch.tensor(sklearn_metric(target, pred), dtype=torch.float, device=device),
torch_metric(pred, target))

pred = torch.randint(10, (200,))
target = torch.randint(5, (200,))

pred = pred.to(device)
target = target.to(device)

assert torch.allclose(
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
torch.tensor(sklearn_metric(target, pred), dtype=torch.float, device=device),
torch_metric(pred, target))

pred = torch.randint(5, (200,))
target = torch.randint(10, (200,))

pred = pred.to(device)
target = target.to(device)

assert torch.allclose(
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
torch.tensor(sklearn_metric(target, pred), dtype=torch.float, device=device),
torch_metric(pred, target))


Expand Down