Binned PR-related metrics #128

Merged
merged 43 commits on Apr 13, 2021
Changes from 10 commits
Commits (43)
f45dd63
WIP: Binned PR-related metrics
maximsch2 Mar 24, 2021
4bb887d
attempt to fix types
maximsch2 Mar 25, 2021
c3a4174
switch to linspace to make old pytorch happy
maximsch2 Mar 26, 2021
df125ef
make flake happy
maximsch2 Mar 26, 2021
0fa6881
Merge branch 'master' of https://github.com/PyTorchLightning/metrics …
maximsch2 Mar 29, 2021
6ac4b34
clean up
maximsch2 Mar 29, 2021
8205ee3
Add more testing, move test input generation to the appropriate place
maximsch2 Apr 7, 2021
cdadbae
Merge branch 'master' of https://github.com/PyTorchLightning/metrics …
maximsch2 Apr 7, 2021
eb70c49
bugfixes and more stable and thorough tests
maximsch2 Apr 7, 2021
15c07f2
flake8
maximsch2 Apr 7, 2021
e1bb5dc
Reuse python zip-based implementation as it can't be reproduced with …
maximsch2 Apr 7, 2021
c39384a
address comments
maximsch2 Apr 7, 2021
b6b289e
isort
maximsch2 Apr 7, 2021
a3c5dd2
Add docs and doctests, make APIs same as non-binned versions
maximsch2 Apr 8, 2021
6e568d9
pep8
maximsch2 Apr 8, 2021
d3a5d9f
isort
maximsch2 Apr 8, 2021
6d5b8b2
doctests likes longer title underlines :O
maximsch2 Apr 8, 2021
a1a7294
use numpy's nan_to_num
maximsch2 Apr 8, 2021
4e276be
add atol to bleu tests to make them more stable
maximsch2 Apr 8, 2021
9ce9745
atol=1e-2 for bleu
maximsch2 Apr 8, 2021
d19e52e
add more docs
maximsch2 Apr 8, 2021
704e7f6
pep8
maximsch2 Apr 8, 2021
e64d69a
Merge branch 'master' into binned_metrics
Borda Apr 9, 2021
00e2d84
remove nlp test hack
maximsch2 Apr 9, 2021
5bdb03f
Merge branch 'binned_metrics' of github.com:maximsch2/metrics into bi…
maximsch2 Apr 9, 2021
42fff5b
Merge branch 'master' into binned_metrics
SkafteNicki Apr 13, 2021
88f83d6
abc
Borda Apr 13, 2021
ee0d541
abc
Borda Apr 13, 2021
1dfea94
Merge branch 'master' of https://github.com/PyTorchLightning/metrics …
maximsch2 Apr 13, 2021
7e673c1
address comments
maximsch2 Apr 13, 2021
d6ac39d
Merge branch 'binned_metrics' of github.com:maximsch2/metrics into bi…
maximsch2 Apr 13, 2021
322e7e3
pep8
maximsch2 Apr 13, 2021
291f04b
abc
maximsch2 Apr 13, 2021
c1ae93e
format
Borda Apr 13, 2021
6218705
Merge branch 'binned_metrics' of https://github.com/maximsch2/metrics…
Borda Apr 13, 2021
72052f4
format
Borda Apr 13, 2021
82abbbf
format
Borda Apr 13, 2021
1b6acd6
format
Borda Apr 13, 2021
31c245a
flake8
maximsch2 Apr 13, 2021
1787d82
Merge branch 'binned_metrics' of github.com:maximsch2/metrics into bi…
maximsch2 Apr 13, 2021
2d54032
remove typecheck
maximsch2 Apr 13, 2021
469f205
chlog
Borda Apr 13, 2021
6c3ec24
Merge branch 'binned_metrics' of https://github.com/maximsch2/metrics…
Borda Apr 13, 2021
26 changes: 26 additions & 0 deletions tests/classification/inputs.py
@@ -77,3 +77,29 @@
    preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)),
    target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM))
)


# Generate plausible-looking inputs
def generate_plausible_inputs_multilabel():
    correct_targets = torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE))
    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)
    targets = torch.zeros_like(preds, dtype=torch.long)
    for i in range(preds.shape[0]):
        for j in range(preds.shape[1]):
            targets[i, j, correct_targets[i, j]] = 1
    preds += torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES) * targets / 3

    preds = preds / preds.sum(dim=2, keepdim=True)

    return Input(preds=preds, target=targets)


def generate_plausible_inputs_binary():
    targets = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
    preds = torch.rand(NUM_BATCHES, BATCH_SIZE) + torch.rand(NUM_BATCHES, BATCH_SIZE) * targets / 3
    return Input(preds=preds, target=targets)


_input_multilabel_prob_plausible = generate_plausible_inputs_multilabel()

_input_binary_prob_plausible = generate_plausible_inputs_binary()
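
For context, a quick sanity check along these lines (an editor's illustrative sketch, not part of the diff) shows why the "plausible" fixtures matter: they skew probability mass toward the true class, so reference metrics score them well above chance, which exercises the non-degenerate parts of the precision-recall curve.

    # Illustrative sketch only; assumes the fixtures above are importable as shown in this diff.
    from sklearn.metrics import average_precision_score
    from tests.classification.inputs import _input_binary_prob_plausible

    preds = _input_binary_prob_plausible.preds.flatten().numpy()
    target = _input_binary_prob_plausible.target.flatten().numpy()
    # For these random-but-skewed inputs this lands noticeably above the ~0.5 positive rate.
    print(average_precision_score(target, preds))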
162 changes: 162 additions & 0 deletions tests/classification/test_binned_precision_recall.py
@@ -0,0 +1,162 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Tuple

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as _sk_average_precision_score
from sklearn.metrics import precision_recall_curve as _sk_precision_recall_curve

from tests.classification.inputs import (
    _input_binary_prob,
    _input_binary_prob_plausible,
    _input_multilabel_prob,
    _input_multilabel_prob_plausible,
)
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision, BinnedRecallAtFixedPrecision

seed_all(42)


def recall_at_precision_x_multilabel(
    predictions: torch.Tensor, targets: torch.Tensor, min_precision: float
) -> Tuple[float, float]:
    precision, recall, thresholds = _sk_precision_recall_curve(
        targets, predictions,
    )

    try:
        max_recall, max_precision, best_threshold = max(
            (r, p, t)
            for p, r, t in zip(precision, recall, thresholds)
            if p >= min_precision
        )
    except ValueError:
        max_recall, best_threshold = 0, 1e6

    return float(max_recall), float(best_threshold)


def _multiclass_prob_sk_metric(predictions, targets, num_classes, min_precision):
    max_recalls = torch.zeros(num_classes)
    best_thresholds = torch.zeros(num_classes)

    for i in range(num_classes):
        max_recalls[i], best_thresholds[i] = recall_at_precision_x_multilabel(
            predictions[:, i], targets[:, i], min_precision
        )
    return max_recalls, best_thresholds


def _binary_prob_sk_metric(predictions, targets, num_classes, min_precision):
    return recall_at_precision_x_multilabel(
        predictions, targets, min_precision
    )


def _multiclass_average_precision_sk_metric(predictions, targets, num_classes):
    return _sk_average_precision_score(targets, predictions, average=None)


@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[
(_input_binary_prob.preds, _input_binary_prob.target, _binary_prob_sk_metric, 1),
(_input_binary_prob_plausible.preds, _input_binary_prob_plausible.target, _binary_prob_sk_metric, 1),
(
_input_multilabel_prob_plausible.preds,
_input_multilabel_prob_plausible.target,
_multiclass_prob_sk_metric,
NUM_CLASSES,
),
(
_input_multilabel_prob.preds,
_input_multilabel_prob.target,
_multiclass_prob_sk_metric,
NUM_CLASSES,
),
],
)
class TestBinnedRecallAtPrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("min_precision", [0.1, 0.3, 0.5, 0.8])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, min_precision):
self.atol = 0.01
# rounding will simulate binning for both implementations
preds = torch.Tensor(np.round(preds.numpy(), 2)) + 1e-6
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=BinnedRecallAtFixedPrecision,
sk_metric=partial(sk_metric, num_classes=num_classes, min_precision=min_precision),
dist_sync_on_step=False,
check_dist_sync_on_step=False,
check_batch=False,
metric_args={
"num_classes": num_classes,
"min_precision": min_precision,
"num_thresholds": 101,
},
)


@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[
(_input_binary_prob.preds, _input_binary_prob.target, _multiclass_average_precision_sk_metric, 1),
(
_input_binary_prob_plausible.preds,
_input_binary_prob_plausible.target,
_multiclass_average_precision_sk_metric,
1,
),
(
_input_multilabel_prob_plausible.preds,
_input_multilabel_prob_plausible.target,
_multiclass_average_precision_sk_metric,
NUM_CLASSES,
),
(
_input_multilabel_prob.preds,
_input_multilabel_prob.target,
_multiclass_average_precision_sk_metric,
NUM_CLASSES,
),
],
)
class TestBinnedAveragePrecision(MetricTester):
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("num_thresholds", [200, 300])
def test_binned_pr(self, preds, target, sk_metric, num_classes, ddp, num_thresholds):
self.atol = 0.01
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=BinnedAveragePrecision,
sk_metric=partial(sk_metric, num_classes=num_classes),
dist_sync_on_step=False,
check_dist_sync_on_step=False,
check_batch=False,
metric_args={
"num_classes": num_classes,
"num_thresholds": num_thresholds,
},
)
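
For intuition about the sklearn-based reference (an illustrative sketch, not part of the diff): recall_at_precision_x_multilabel walks the precision-recall curve, keeps only points whose precision clears min_precision, and returns the highest recall among them together with its threshold, falling back to recall 0 and threshold 1e6 when no point qualifies. The rounding in the test above makes the continuous sklearn curve and the 101-bin implementation evaluate essentially the same threshold grid.

    # Tiny hand-checkable example (hypothetical numbers).
    import torch
    from sklearn.metrics import precision_recall_curve

    target = torch.tensor([0, 1, 1, 0, 1])
    preds = torch.tensor([0.1, 0.9, 0.8, 0.4, 0.3])
    precision, recall, thresholds = precision_recall_curve(target, preds)
    # With min_precision=1.0, only the two top-scored positives qualify,
    # so the best recall is 2/3, reached at threshold 0.8.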
146 changes: 146 additions & 0 deletions torchmetrics/classification/binned_precision_recall.py
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union

import torch

from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import METRIC_EPS, to_onehot


class BinnedPrecisionRecallCurve(Metric):
"""Returns a tensor of recalls for a fixed precision threshold.
It is a tensor instead of a single number, because it applies to multi-label inputs.
"""

    TPs: torch.Tensor
    FPs: torch.Tensor
    FNs: torch.Tensor
    thresholds: torch.Tensor

    def __init__(
        self,
        num_classes: int,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        assert not compute_on_step, "computation on each step is not supported"
        super().__init__(compute_on_step=False, **kwargs)
        self.num_classes = num_classes
        self.num_thresholds = num_thresholds
        thresholds = torch.linspace(0, 1.0 + METRIC_EPS, num_thresholds)
        self.register_buffer("thresholds", thresholds)

        for name in ("TPs", "FPs", "FNs"):
            self.add_state(
                name=name,
                default=torch.zeros(num_classes, num_thresholds),
                dist_reduce_fx="sum",
            )

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        """
        Args:
            preds: (n_samples, n_classes) tensor
            targets: (n_samples, n_classes) tensor
        """
        # binary case
        if len(preds.shape) == len(targets.shape) == 1:
            preds = preds.reshape(-1, 1)
            targets = targets.reshape(-1, 1)

        if len(preds.shape) == len(targets.shape) + 1:
            targets = to_onehot(targets, num_classes=self.num_classes)

        targets = targets == 1
        # Iterate one threshold at a time to conserve memory
        for i in range(self.num_thresholds):
            predictions = preds >= self.thresholds[i]
            self.TPs[:, i] += (targets & predictions).sum(dim=0)
            self.FPs[:, i] += ((~targets) & (predictions)).sum(dim=0)
            self.FNs[:, i] += ((targets) & (~predictions)).sum(dim=0)

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns float tensors of precisions, recalls and thresholds (per class for multi-label input)"""
        precisions = (self.TPs + METRIC_EPS) / (self.TPs + self.FPs + METRIC_EPS)
        recalls = self.TPs / (self.TPs + self.FNs + METRIC_EPS)
        # Need to guarantee that last precision=1 and recall=0
        precisions = torch.cat([precisions, torch.ones(self.num_classes, 1,
                                                       dtype=precisions.dtype, device=precisions.device)], dim=1)
        recalls = torch.cat([recalls, torch.zeros(self.num_classes, 1,
                                                  dtype=recalls.dtype, device=recalls.device)], dim=1)
        thresholds = torch.cat([self.thresholds, torch.ones(1, dtype=recalls.dtype, device=recalls.device)], dim=0)
        if self.num_classes == 1:
            return (precisions[0, :], recalls[0, :], thresholds)
        else:
            return (precisions, recalls, thresholds)


class BinnedAveragePrecision(BinnedPrecisionRecallCurve):
    def compute(self) -> Union[List[torch.Tensor], torch.Tensor]:
        precisions, recalls, _ = super().compute()
        return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes)


class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve):
    def __init__(
        self,
        num_classes: int,
        min_precision: float,
        num_thresholds: int = 100,
        compute_on_step: bool = False,  # will ignore this
        **kwargs
    ):
        super().__init__(
            num_classes=num_classes,
            num_thresholds=num_thresholds,
            compute_on_step=compute_on_step,
            **kwargs
        )
        self.min_precision = min_precision

    def compute(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns float tensors of recalls and thresholds, each of size n_classes"""
        precisions, recalls, thresholds = super().compute()
        condition = precisions >= self.min_precision

        if self.num_classes == 1:
            recall_at_p, index = (
                torch.where(
                    condition, recalls, torch.scalar_tensor(0.0, device=condition.device)
                )
                .max(dim=0)
            )
            if recall_at_p == 0.0:
                return recall_at_p, torch.scalar_tensor(1e6, device=condition.device)
            else:
                return recall_at_p, thresholds[index]

        recalls_at_p, indices = (
            torch.where(
                condition, recalls, torch.scalar_tensor(0.0, device=condition.device)
            )
            .max(dim=1)
        )

        thresholds_at_p = torch.zeros_like(recalls_at_p, device=condition.device, dtype=thresholds.dtype)
        for i in range(self.num_classes):
            if recalls_at_p[i] == 0.0:
                thresholds_at_p[i] = 1e6
            else:
                thresholds_at_p[i] = thresholds[indices[i]]

        return (recalls_at_p, thresholds_at_p)
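
For orientation, a minimal usage sketch of the two new modules (an editor's illustration based on the APIs in this diff; the data and numbers are made up):

    import torch
    from torchmetrics.classification.binned_precision_recall import (
        BinnedAveragePrecision,
        BinnedRecallAtFixedPrecision,
    )

    preds = torch.rand(64, 3)                  # per-label probabilities
    target = (torch.rand(64, 3) > 0.5).long()  # multi-label ground truth

    ap = BinnedAveragePrecision(num_classes=3, num_thresholds=101)
    ap.update(preds, target)
    print(ap.compute())  # one binned average-precision value per class

    r_at_p = BinnedRecallAtFixedPrecision(num_classes=3, min_precision=0.8, num_thresholds=101)
    r_at_p.update(preds, target)
    recalls, thresholds = r_at_p.compute()  # best recall per class and the threshold achieving it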
8 changes: 8 additions & 0 deletions torchmetrics/functional/classification/average_precision.py
@@ -40,6 +40,14 @@ def _average_precision_compute(
) -> Union[List[Tensor], Tensor]:
    # todo: `sample_weights` is unused
    precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
    return _average_precision_compute_with_precision_recall(precision, recall, num_classes)


def _average_precision_compute_with_precision_recall(
    precision: Tensor,
    recall: Tensor,
    num_classes: int,
) -> Union[List[Tensor], Tensor]:
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
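    # (Editor's note, illustrative only: the remainder of this function is collapsed
    # in the diff view. The step-function integral referred to above is the standard
    # one used by sklearn's average_precision_score, roughly
    #     -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
    # for a single class, i.e. precision weighted by the drop in recall between
    # consecutive thresholds.)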