Skip to content

Commit

Permalink
Merge branch 'master' into bugfix/ergas
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Apr 15, 2024
2 parents b4a59d5 + 6e088fe commit c580d19
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


- Fixed `BootStrapper` wrapper not working with `kwargs` provided argument ([#2503](https://github.com/Lightning-AI/torchmetrics/pull/2503))


- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501))


Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/wrappers/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def update(self, *args: Any, **kwargs: Any) -> None:
"""
args_sizes = apply_to_collection(args, Tensor, len)
kwargs_sizes = list(apply_to_collection(kwargs, Tensor, len))
kwargs_sizes = apply_to_collection(kwargs, Tensor, len)
if len(args_sizes) > 0:
size = args_sizes[0]
elif len(kwargs_sizes) > 0:
size = kwargs_sizes[0]
size = next(iter(kwargs_sizes.values()))
else:
raise ValueError("None of the input contained tensors, so could not determine the sampling size")

Expand Down
22 changes: 21 additions & 1 deletion tests/unittests/wrappers/test_bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sklearn.metrics import mean_squared_error, precision_score, recall_score
from torch import Tensor
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall
from torchmetrics.regression import MeanSquaredError
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

from unittests._helpers import seed_all
Expand Down Expand Up @@ -140,3 +140,23 @@ def test_low_sample_amount(sampling_strategy):
MulticlassF1Score(num_classes=3, average=None), num_bootstraps=20, sampling_strategy=sampling_strategy
)
assert bootstrap_f1(preds, target) # does not work


def test_args_and_kwargs_works():
"""Test that metric works with both args and kwargs and mix.
See issue: https://github.com/Lightning-AI/torchmetrics/issues/2450
"""
x = torch.rand(100)
y = x + torch.randn_like(x)
ae = MeanAbsoluteError()
assert ae(x, y) == ae(preds=x, target=y)

bootstrapped_ae = BootStrapper(ae)
res1 = bootstrapped_ae(x, y)
res2 = bootstrapped_ae(x, target=y)
res3 = bootstrapped_ae(preds=x, target=y)

assert (res1["mean"].shape == res2["mean"].shape) & (res2["mean"].shape == res3["mean"].shape)
assert (res1["std"].shape == res2["std"].shape) & (res2["mean"].shape == res3["std"].shape)

0 comments on commit c580d19

Please sign in to comment.