Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric bootstrapper #101

Merged
merged 45 commits into from
Mar 27, 2021
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
e94c4ba
add bootstrapping
SkafteNicki Mar 17, 2021
96967e4
tests
SkafteNicki Mar 17, 2021
93c30a8
pep8
SkafteNicki Mar 17, 2021
d785e6c
move args to init
SkafteNicki Mar 18, 2021
0d23b7c
fix tests
SkafteNicki Mar 18, 2021
7a08934
fix tests
SkafteNicki Mar 18, 2021
d1c0482
mypy
SkafteNicki Mar 18, 2021
25b38ad
remove pdb
SkafteNicki Mar 18, 2021
beac9e0
add bootstrapping
SkafteNicki Mar 17, 2021
807a4e2
tests
SkafteNicki Mar 17, 2021
7caabbe
pep8
SkafteNicki Mar 17, 2021
853614c
move args to init
SkafteNicki Mar 18, 2021
58980d6
fix tests
SkafteNicki Mar 18, 2021
86de867
fix tests
SkafteNicki Mar 18, 2021
191509b
mypy
SkafteNicki Mar 18, 2021
3bce9c1
remove pdb
SkafteNicki Mar 18, 2021
56fd2bc
versions
Borda Mar 23, 2021
e3c2a24
versions
Borda Mar 23, 2021
cf02eba
Update docs/source/references/modules.rst
SkafteNicki Mar 24, 2021
73cb3f6
merge
SkafteNicki Mar 24, 2021
9ef1b47
isort
SkafteNicki Mar 24, 2021
557ae6a
Apply suggestions from code review
Borda Mar 24, 2021
fa149f2
Update torchmetrics/wrappers/bootstrapping.py
SkafteNicki Mar 24, 2021
b022c5a
Merge branch 'master' into bootstrap
Borda Mar 24, 2021
c5364d0
update
Borda Mar 24, 2021
a3e9b40
update
Borda Mar 24, 2021
7b211ea
add poisson
SkafteNicki Mar 25, 2021
ca0f812
poisson
SkafteNicki Mar 25, 2021
e80f0ce
pep8
SkafteNicki Mar 25, 2021
e7a49f0
Merge branch 'master' into bootstrap
SkafteNicki Mar 25, 2021
cbf8a67
revert
SkafteNicki Mar 25, 2021
e5929ec
Merge branch 'bootstrap' of https://github.com/PyTorchLightning/metri…
SkafteNicki Mar 25, 2021
b0cb0d7
link
SkafteNicki Mar 25, 2021
e7922ac
isort
SkafteNicki Mar 25, 2021
6b7ebf8
roc changes remove
SkafteNicki Mar 25, 2021
ed825d5
fix
SkafteNicki Mar 25, 2021
1ad98af
Merge branch 'master' into bootstrap
SkafteNicki Mar 25, 2021
0f2ed2e
Merge branch 'master' into bootstrap
SkafteNicki Mar 26, 2021
944a131
fix tests
SkafteNicki Mar 26, 2021
cf527e9
pep8
SkafteNicki Mar 26, 2021
0987e8c
Apply suggestions from code review
Borda Mar 26, 2021
31bc041
Merge branch 'master' into bootstrap
SkafteNicki Mar 26, 2021
7997a03
suggestions
SkafteNicki Mar 26, 2021
634bb6a
Merge branch 'master' into bootstrap
Borda Mar 27, 2021
03024d8
pprint
Borda Mar 27, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))


- Added `BootStrapper` to easely calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
Expand Down
12 changes: 11 additions & 1 deletion docs/source/references/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,14 @@ R2Score
~~~~~~~

.. autoclass:: torchmetrics.R2Score
:noindex:
:noindex:

********
Wrappers
********

Modular wrapper metrics are not metrics in themself, but instead take a metric and alter the internal logic
of the base metric.

.. autoclass:: torchmetrics.BootStrapper
:noindex:
32 changes: 24 additions & 8 deletions tests/classification/test_roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,38 @@ def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1):
def _sk_roc_multilabel_prob(preds, target, num_classes=1):
sk_preds = preds.numpy()
sk_target = target.numpy()
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True)
return _sk_roc_curve(
y_true=sk_target,
probas_pred=sk_preds,
num_classes=num_classes,
multilabel=True
)


def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True)
return _sk_roc_curve(
y_true=sk_target,
probas_pred=sk_preds,
num_classes=num_classes,
multilabel=True
)


@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes",
[(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
(_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
(_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES)]
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES),
(_input_multilabel_prob.preds, _input_multilabel_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES),
(
_input_multilabel_multidim_prob.preds,
_input_multilabel_multidim_prob.target,
_sk_roc_multilabel_multidim_prob,
NUM_CLASSES
)
]
)
class TestROC(MetricTester):

Expand Down
Empty file added tests/wrappers/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions tests/wrappers/test_bootstrapping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from torch import Tensor
from sklearn.metrics import precision_score, recall_score

from torchmetrics.classification import Precision, Recall
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

_preds = torch.randint(10, (10, 32))
_target = torch.randint(10, (10, 32))


class TestBootStrapper(BootStrapper):
""" For testing purpose, we subclass the bootstrapper class so we can get the exact permutation
the class is creating
"""
def update(self, *args) -> None:
self.out = []
for idx in range(self.num_bootstraps):
size = len(args[0])
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args)
self.out.append(new_args)

SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
def test_bootstrap_sampler(sampling_strategy):
""" make sure that the bootstrap sampler works as intended """
old_samples = torch.randn(10, 2)

# make sure that the new samples are only made up of old samples
idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
new_samples = old_samples[idx]
for ns in new_samples:
assert ns in old_samples

# make sure some samples are also sampled twice
found_one = False
for os in old_samples:
cond = os == new_samples
if cond.sum() > 2:
found_one = True
break

assert found_one, "resampling did not work because no samples were sampled twice"

# make sure some samples are never sampled
found_zero = False
for os in old_samples:
cond = os != new_samples
if cond.sum() > 0:
found_zero = True
break

assert found_zero, "resampling did not work because all samples were atleast sampled once"


@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
@pytest.mark.parametrize(
"metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]]
)
def test_bootstrap(sampling_strategy, metric, sk_metric):
""" Test that the different bootstraps gets updated as we expected and that the compute method works """
_kwargs = {'base_metric': metric, 'mean': True, 'std': True, 'raw': True, 'sampling_strategy': sampling_strategy}
if _TORCH_GREATER_EQUAL_1_7:
bootstrapper = TestBootStrapper(**_kwargs, quantile=torch.tensor([0.05, 0.95]))
else:
bootstrapper = TestBootStrapper(**_kwargs)
Borda marked this conversation as resolved.
Show resolved Hide resolved

collected_preds = [[] for _ in range(10)]
collected_target = [[] for _ in range(10)]
for p, t in zip(_preds, _target):
bootstrapper.update(p, t)

for i, o in enumerate(bootstrapper.out):

collected_preds[i].append(o[0])
collected_target[i].append(o[1])

collected_preds = [torch.cat(cp) for cp in collected_preds]
collected_target = [torch.cat(ct) for ct in collected_target]

sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)]

output = bootstrapper.compute()
# quantile only avaible for pytorch v1.7 and forward
if _TORCH_GREATER_EQUAL_1_7:
pl_mean, pl_std, pl_quantile, pl_raw = output
assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05))
assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95))
else:
pl_mean, pl_std, pl_raw = output

assert np.allclose(pl_mean, np.mean(sk_scores))
assert np.allclose(pl_std, np.std(sk_scores, ddof=1))
assert np.allclose(pl_raw, sk_scores)
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
48 changes: 44 additions & 4 deletions torchmetrics/utilities/imports.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities"""
import importlib
import operator
from distutils.version import LooseVersion

import torch
from pkg_resources import DistributionNotFound

_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")

def _compare_version(package: str, op, version) -> bool:
"""
Compare package version with some requirements

>>> _compare_version("torch", operator.ge, "0.1")
True
"""
try:
pkg = importlib.import_module(package)
except (ModuleNotFoundError, DistributionNotFound):
return False
try:
pkg_version = LooseVersion(pkg.__version__)
except AttributeError:
return False
if not (hasattr(pkg_version, "vstring") and hasattr(pkg_version, "version")):
# this is mock by sphinx, so it shall return True ro generate all summaries
return True
return op(pkg_version, LooseVersion(version))


_TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0")
_TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0")
_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
14 changes: 14 additions & 0 deletions torchmetrics/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
Loading