Metric bootstrapper #101

Merged on Mar 27, 2021 (45 commits; changes shown from 40 commits)

Commits
e94c4ba
add bootstrapping
SkafteNicki Mar 17, 2021
96967e4
tests
SkafteNicki Mar 17, 2021
93c30a8
pep8
SkafteNicki Mar 17, 2021
d785e6c
move args to init
SkafteNicki Mar 18, 2021
0d23b7c
fix tests
SkafteNicki Mar 18, 2021
7a08934
fix tests
SkafteNicki Mar 18, 2021
d1c0482
mypy
SkafteNicki Mar 18, 2021
25b38ad
remove pdb
SkafteNicki Mar 18, 2021
beac9e0
add bootstrapping
SkafteNicki Mar 17, 2021
807a4e2
tests
SkafteNicki Mar 17, 2021
7caabbe
pep8
SkafteNicki Mar 17, 2021
853614c
move args to init
SkafteNicki Mar 18, 2021
58980d6
fix tests
SkafteNicki Mar 18, 2021
86de867
fix tests
SkafteNicki Mar 18, 2021
191509b
mypy
SkafteNicki Mar 18, 2021
3bce9c1
remove pdb
SkafteNicki Mar 18, 2021
56fd2bc
versions
Borda Mar 23, 2021
e3c2a24
versions
Borda Mar 23, 2021
cf02eba
Update docs/source/references/modules.rst
SkafteNicki Mar 24, 2021
73cb3f6
merge
SkafteNicki Mar 24, 2021
9ef1b47
isort
SkafteNicki Mar 24, 2021
557ae6a
Apply suggestions from code review
Borda Mar 24, 2021
fa149f2
Update torchmetrics/wrappers/bootstrapping.py
SkafteNicki Mar 24, 2021
b022c5a
Merge branch 'master' into bootstrap
Borda Mar 24, 2021
c5364d0
update
Borda Mar 24, 2021
a3e9b40
update
Borda Mar 24, 2021
7b211ea
add poisson
SkafteNicki Mar 25, 2021
ca0f812
poisson
SkafteNicki Mar 25, 2021
e80f0ce
pep8
SkafteNicki Mar 25, 2021
e7a49f0
Merge branch 'master' into bootstrap
SkafteNicki Mar 25, 2021
cbf8a67
revert
SkafteNicki Mar 25, 2021
e5929ec
Merge branch 'bootstrap' of https://github.com/PyTorchLightning/metri…
SkafteNicki Mar 25, 2021
b0cb0d7
link
SkafteNicki Mar 25, 2021
e7922ac
isort
SkafteNicki Mar 25, 2021
6b7ebf8
roc changes remove
SkafteNicki Mar 25, 2021
ed825d5
fix
SkafteNicki Mar 25, 2021
1ad98af
Merge branch 'master' into bootstrap
SkafteNicki Mar 25, 2021
0f2ed2e
Merge branch 'master' into bootstrap
SkafteNicki Mar 26, 2021
944a131
fix tests
SkafteNicki Mar 26, 2021
cf527e9
pep8
SkafteNicki Mar 26, 2021
0987e8c
Apply suggestions from code review
Borda Mar 26, 2021
31bc041
Merge branch 'master' into bootstrap
SkafteNicki Mar 26, 2021
7997a03
suggestions
SkafteNicki Mar 26, 2021
634bb6a
Merge branch 'master' into bootstrap
Borda Mar 27, 2021
03024d8
pprint
Borda Mar 27, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added multilabel support to `ROC` metric ([#114](https://github.com/PyTorchLightning/metrics/pull/114))


- Added `BootStrapper` to easily calculate confidence intervals for metrics ([#101](https://github.com/PyTorchLightning/metrics/pull/101))


### Changed

- Changed `ExplainedVariance` from storing all preds/targets to tracking 5 statistics ([#68](https://github.com/PyTorchLightning/metrics/pull/68))
11 changes: 11 additions & 0 deletions docs/source/references/modules.rst
@@ -334,3 +334,14 @@ RetrievalMRR

.. autoclass:: torchmetrics.RetrievalMRR
:noindex:


********
Wrappers
********

Modular wrapper metrics are not metrics in themselves, but instead take a metric and alter the internal logic
of the base metric.

.. autoclass:: torchmetrics.BootStrapper
:noindex:
Empty file added tests/wrappers/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions tests/wrappers/test_bootstrapping.py
@@ -0,0 +1,113 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from torch import Tensor

from torchmetrics.classification import Precision, Recall
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

_preds = torch.randint(10, (10, 32))
_target = torch.randint(10, (10, 32))


class TestBootStrapper(BootStrapper):
""" For testing purposes, we subclass the bootstrapper class so we can get the exact permutation
the class is creating
"""
def update(self, *args) -> None:
self.out = []
for idx in range(self.num_bootstraps):
size = len(args[0])
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args)
self.out.append(new_args)


@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
def test_bootstrap_sampler(sampling_strategy):
""" make sure that the bootstrap sampler works as intended """
old_samples = torch.randn(10, 2)

# make sure that the new samples are only made up of old samples
idx = _bootstrap_sampler(10, sampling_strategy=sampling_strategy)
new_samples = old_samples[idx]
for ns in new_samples:
assert ns in old_samples

# make sure some samples are also sampled twice
found_one = False
for os in old_samples:
cond = os == new_samples
if cond.sum() > 2:
found_one = True
break

assert found_one, "resampling did not work because no samples were sampled twice"

# make sure some samples are never sampled
found_zero = False
for os in old_samples:
# a sample was never drawn when no row of new_samples matches it
cond = (os == new_samples).all(dim=1)
if not cond.any():
found_zero = True
break

assert found_zero, "resampling did not work because all samples were sampled at least once"


@pytest.mark.parametrize("sampling_strategy", ['poisson', 'multinomial'])
@pytest.mark.parametrize(
"metric, sk_metric", [[Precision(average='micro'), precision_score], [Recall(average='micro'), recall_score]]
)
def test_bootstrap(sampling_strategy, metric, sk_metric):
""" Test that the different bootstraps get updated as expected and that the compute method works """
_kwargs = {'base_metric': metric, 'mean': True, 'std': True, 'raw': True, 'sampling_strategy': sampling_strategy}
if _TORCH_GREATER_EQUAL_1_7:
bootstrapper = TestBootStrapper(**_kwargs, quantile=torch.tensor([0.05, 0.95]))
else:
bootstrapper = TestBootStrapper(**_kwargs)

collected_preds = [[] for _ in range(10)]
collected_target = [[] for _ in range(10)]
for p, t in zip(_preds, _target):
bootstrapper.update(p, t)

for i, o in enumerate(bootstrapper.out):

collected_preds[i].append(o[0])
collected_target[i].append(o[1])

collected_preds = [torch.cat(cp) for cp in collected_preds]
collected_target = [torch.cat(ct) for ct in collected_target]

sk_scores = [sk_metric(ct, cp, average='micro') for ct, cp in zip(collected_target, collected_preds)]

output = bootstrapper.compute()
# quantile is only available for pytorch v1.7 and later
if _TORCH_GREATER_EQUAL_1_7:
pl_mean, pl_std, pl_quantile, pl_raw = output
assert np.allclose(pl_quantile[0], np.quantile(sk_scores, 0.05))
assert np.allclose(pl_quantile[1], np.quantile(sk_scores, 0.95))
else:
pl_mean, pl_std, pl_raw = output

assert np.allclose(pl_mean, np.mean(sk_scores))
assert np.allclose(pl_std, np.std(sk_scores, ddof=1))
assert np.allclose(pl_raw, sk_scores)
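
The assertions above compare against `np.std(sk_scores, ddof=1)` because `torch.std` applies Bessel's correction by default, and both `torch.quantile` and `np.quantile` interpolate linearly by default. As a minimal standard-library sketch of those summary statistics (the scores are made-up illustration values and `linear_quantile` is a hypothetical helper, not part of this PR):

```python
import math
import statistics


def linear_quantile(values, q):
    """Quantile with linear interpolation between the two nearest order statistics."""
    ordered = sorted(values)
    position = q * (len(ordered) - 1)
    lower = math.floor(position)
    upper = math.ceil(position)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


# Made-up bootstrap scores, for illustration only.
scores = [0.20, 0.25, 0.15, 0.30, 0.22]

mean = statistics.mean(scores)
sample_std = statistics.stdev(scores)       # divides by n - 1 (Bessel's correction)
population_std = statistics.pstdev(scores)  # divides by n
interval = (linear_quantile(scores, 0.05), linear_quantile(scores, 0.95))

print(mean, sample_std, population_std, interval)
```

The sample standard deviation is always at least as large as the population one, so comparing against the wrong variant would make the `std` assertion fail for small numbers of bootstraps.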
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -50,3 +50,4 @@
R2Score,
)
from torchmetrics.retrieval import RetrievalMAP, RetrievalMRR # noqa: F401 E402
from torchmetrics.wrappers import BootStrapper # noqa: F401 E402
12 changes: 8 additions & 4 deletions torchmetrics/utilities/imports.py
@@ -11,11 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Import utilities"""
import operator
from distutils.version import LooseVersion
from importlib import import_module
from importlib.util import find_spec

import torch
import torch # noqa: F401
from pkg_resources import DistributionNotFound


@@ -60,6 +62,8 @@ def _compare_version(package: str, op, version) -> bool:
return op(pkg_version, LooseVersion(version))


_TORCH_LOWER_1_4 = LooseVersion(torch.__version__) < LooseVersion("1.4.0")
_TORCH_LOWER_1_5 = LooseVersion(torch.__version__) < LooseVersion("1.5.0")
_TORCH_LOWER_1_6 = LooseVersion(torch.__version__) < LooseVersion("1.6.0")
_TORCH_LOWER_1_4 = _compare_version("torch", operator.lt, "1.4.0")
_TORCH_LOWER_1_5 = _compare_version("torch", operator.lt, "1.5.0")
_TORCH_LOWER_1_6 = _compare_version("torch", operator.lt, "1.6.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
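
The flags above are built by passing an `operator` function to `_compare_version`. The sketch below illustrates that pattern with the standard library only. `parse_version` and `compare_version` are hypothetical helpers that handle plain dotted numeric versions; the helper in this file instead looks up the installed package and compares with `LooseVersion`.

```python
import operator
import re


def parse_version(version: str) -> tuple:
    """Turn the leading dotted-numeric part of a version string into a tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def compare_version(installed: str, op, required: str) -> bool:
    """Apply a comparison operator such as operator.ge to two parsed versions."""
    return op(parse_version(installed), parse_version(required))


# A local build suffix such as "+cu110" is ignored by parse_version.
print(compare_version("1.7.0+cu110", operator.ge, "1.7.0"))
print(compare_version("1.6.0", operator.ge, "1.7.0"))
```

Comparing integer tuples rather than raw strings matters once a component reaches two digits: as strings, `"1.10.0"` sorts before `"1.7.0"`.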
14 changes: 14 additions & 0 deletions torchmetrics/wrappers/__init__.py
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.wrappers.bootstrapping import BootStrapper # noqa: F401
174 changes: 174 additions & 0 deletions torchmetrics/wrappers/bootstrapping.py
@@ -0,0 +1,174 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, Callable, List, Optional, Union

import torch
from torch import Tensor, nn

from torchmetrics.metric import Metric
from torchmetrics.utilities import apply_to_collection
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7


def _bootstrap_sampler(
size: int,
sampling_strategy: str = 'poisson'
) -> Tensor:
""" Sample indices for resampling a tensor along its first dimension with replacement

Args:
size: number of samples
sampling_strategy: the strategy to use for sampling, either ``'poisson'`` or ``'multinomial'``

Returns:
tensor of indices into the first dimension that selects the resampled elements

"""
if sampling_strategy == 'poisson':
p = torch.distributions.Poisson(1)
n = p.sample((size,))
return torch.arange(size).repeat_interleave(n.long(), dim=0)
elif sampling_strategy == 'multinomial':
idx = torch.multinomial(
torch.ones(size),
num_samples=size,
replacement=True
)
return idx
raise ValueError('Unknown sampling strategy')
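
To compare the two strategies, here is a minimal standard-library sketch. It is not the code in this PR, which relies on `torch.distributions.Poisson` and `torch.multinomial`, and the helper names `_poisson_draw`, `poisson_indices` and `multinomial_indices` are hypothetical. It shows that Poisson resampling yields a variable number of indices, while multinomial resampling always yields exactly `size` indices.

```python
import math
import random


def _poisson_draw(rng: random.Random, lam: float = 1.0) -> int:
    """Draw one Poisson(lam) variate with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def poisson_indices(size: int, rng: random.Random) -> list:
    """Repeat index i a Poisson(1)-distributed number of times (possibly zero)."""
    indices = []
    for i in range(size):
        indices.extend([i] * _poisson_draw(rng))
    return indices


def multinomial_indices(size: int, rng: random.Random) -> list:
    """Draw exactly `size` indices uniformly with replacement."""
    return [rng.randrange(size) for _ in range(size)]


rng = random.Random(0)
pois = poisson_indices(1000, rng)
multi = multinomial_indices(1000, rng)
# The first length varies around 1000; the second is always exactly 1000.
print(len(pois), len(multi))
```

Because the Poisson variant can change the batch size, it only suits metrics whose `update` accepts inputs of varying length along the first dimension.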


class BootStrapper(Metric):

def __init__(
self,
base_metric: Metric,
num_bootstraps: int = 10,
mean: bool = True,
std: bool = True,
quantile: Optional[Union[float, Tensor]] = None,
raw: bool = False,
sampling_strategy: str = 'poisson',
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None
) -> None:
r"""
Use to turn a metric into a `bootstrapped <https://en.wikipedia.org/wiki/Bootstrapping_(statistics)>`_
metric that can automate the process of getting confidence intervals for metric values. This wrapper
class basically keeps multiple copies of the same base metric in memory and whenever ``update`` or
``forward`` is called, all input tensors are resampled (with replacement) along the first dimension.

Args:
base_metric:
base metric class to wrap
num_bootstraps:
number of copies to make of the base metric for bootstrapping
mean:
if ``True`` return the mean of the bootstraps
std:
if ``True`` return the standard deviation of the bootstraps
quantile:
if given, returns the quantile of the bootstraps. Can only be used with
pytorch version 1.7 or higher
raw:
if ``True``, return all bootstrapped values
sampling_strategy:
Determines how to produce bootstrapped samplings. Either ``'poisson'`` or ``'multinomial'``.
If ``'poisson'`` is chosen, the number of times each sample will be included in the bootstrap
will be given by :math:`n\sim Poisson(\lambda=1)`, which approximates the true bootstrap distribution
when the number of samples is large. If ``'multinomial'`` is chosen, we will apply true bootstrapping
at the batch level to approximate bootstrapping over the whole dataset.
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step
process_group:
Specify the process group on which synchronization is called.
default: ``None`` (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather.

Example::
>>> from torchmetrics import Accuracy, BootStrapper
>>> _ = torch.manual_seed(123)
>>> base_metric = Accuracy()
>>> bootstrap = BootStrapper(base_metric, num_bootstraps=20)
>>> bootstrap.update(torch.randint(5, (20,)), torch.randint(5, (20,)))
>>> output = bootstrap.compute()
>>> mean, std = output
>>> print(mean, std)
tensor(0.2205) tensor(0.0859)

"""
super().__init__(compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)
if not isinstance(base_metric, Metric):
raise ValueError(
"Expected base metric to be an instance of torchmetrics.Metric"
f" but received {base_metric}"
)

self.metrics = nn.ModuleList([deepcopy(base_metric) for _ in range(num_bootstraps)])
self.num_bootstraps = num_bootstraps

self.mean = mean
self.std = std
if quantile is not None and not _TORCH_GREATER_EQUAL_1_7:
raise ValueError('quantile argument can only be used with pytorch v1.7 or higher')
self.quantile = quantile
self.raw = raw

allowed_sampling = ('poisson', 'multinomial')
if sampling_strategy not in allowed_sampling:
raise ValueError(
f"Expected argument ``sampling_strategy`` to be one of {allowed_sampling}"
f" but received {sampling_strategy}"
)
self.sampling_strategy = sampling_strategy

def update(self, *args: Any, **kwargs: Any) -> None:
""" Updates the state of the base metric. Any tensor passed in will be bootstrapped along dimension 0 """
for idx in range(self.num_bootstraps):
args_sizes = apply_to_collection(args, Tensor, len)
kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
if len(args_sizes) > 0:
size = args_sizes[0]
elif len(kwargs_sizes) > 0:
size = kwargs_sizes[0]
else:
raise ValueError('None of the input contained tensors, so could not determine the sampling size')
sample_idx = _bootstrap_sampler(size, sampling_strategy=self.sampling_strategy)
new_args = apply_to_collection(args, Tensor, torch.index_select, dim=0, index=sample_idx)
new_kwargs = apply_to_collection(kwargs, Tensor, torch.index_select, dim=0, index=sample_idx)
self.metrics[idx].update(*new_args, **new_kwargs)

def compute(self) -> List[Tensor]:
""" Computes the bootstrapped metric values. Always returns a list of tensors, but the content of
the list depends on how the class was initialized
"""
computed_vals = torch.stack([m.compute() for m in self.metrics], dim=0)
output = []
if self.mean:
output.append(computed_vals.mean(dim=0))
if self.std:
output.append(computed_vals.std(dim=0))
if self.quantile is not None:
output.append(torch.quantile(computed_vals, self.quantile))
if self.raw:
output.append(computed_vals)
return output
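
The wrapper's overall flow (copy the base metric, resample every batch once per copy, then summarise the per-copy results) can be sketched without PyTorch. This is an illustration under stated assumptions, not the implementation in this PR: `RunningAccuracy` and `ToyBootstrapper` are hypothetical stand-ins, the data is random, and only multinomial-style resampling is shown.

```python
import random
import statistics
from copy import deepcopy


class RunningAccuracy:
    """Toy metric: fraction of predictions equal to their targets."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, targets):
        self.correct += sum(p == t for p, t in zip(preds, targets))
        self.total += len(preds)

    def compute(self):
        return self.correct / self.total


class ToyBootstrapper:
    """Keep independent copies of a metric and feed each a resampled batch."""

    def __init__(self, base_metric, num_bootstraps=10, seed=0):
        self.metrics = [deepcopy(base_metric) for _ in range(num_bootstraps)]
        self.rng = random.Random(seed)

    def update(self, preds, targets):
        size = len(preds)
        for metric in self.metrics:
            # One index list per copy, applied to every input so pairs stay aligned.
            idx = [self.rng.randrange(size) for _ in range(size)]
            metric.update([preds[i] for i in idx], [targets[i] for i in idx])

    def compute(self):
        values = [m.compute() for m in self.metrics]
        return statistics.mean(values), statistics.stdev(values)


data_rng = random.Random(1)
preds = [data_rng.randrange(5) for _ in range(200)]
targets = [data_rng.randrange(5) for _ in range(200)]

bootstrap = ToyBootstrapper(RunningAccuracy(), num_bootstraps=20)
bootstrap.update(preds, targets)
mean, std = bootstrap.compute()
print(mean, std)
```

Using the same index list for predictions and targets is the important detail: resampling each input independently would break the pairing and make the metric meaningless.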