Skip to content

Commit 2856e0b

Browse files
SkafteNickiBordamergify[bot]
authored
Fix broken clone method for classification metrics (#1250)
* fix clone method * chlog Co-authored-by: Jirka <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 5b3eb5c commit 2856e0b

File tree

3 files changed

+13
-5
lines changed

3 files changed

+13
-5
lines changed

CHANGELOG.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
### Changed
1818

19-
-
19+
-
2020

2121

2222
### Deprecated
2323

24-
-
24+
-
2525

2626

2727
### Removed
2828

29-
-
29+
-
3030

3131

3232
### Fixed
3333

34-
-
34+
- Fixed broken clone method for classification metrics ([#1250](https://github.com/Lightning-AI/metrics/pull/1250))
3535

3636

3737
## [0.10.0] - 2022-10-04

src/torchmetrics/metric.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from abc import ABC, abstractmethod
1717
from contextlib import contextmanager
1818
from copy import deepcopy
19-
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
19+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union
2020

2121
import torch
2222
from torch import Tensor
@@ -848,6 +848,9 @@ def __pos__(self) -> "Metric":
848848
def __getitem__(self, idx: int) -> "Metric":
849849
return CompositionalMetric(lambda x: x[idx], self, None)
850850

851+
def __getnewargs__(self) -> Tuple:
852+
return (Metric.__str__(self),)
853+
851854

852855
def _neg(x: Tensor) -> Tensor:
853856
return -torch.abs(x)

tests/unittests/helpers/testers.py

+5
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,11 @@ def _class_test(
165165
if check_scriptable:
166166
torch.jit.script(metric)
167167

168+
# check that metric can be cloned
169+
clone = metric.clone()
170+
assert clone is not metric, "Clone is not a different object than the metric"
171+
assert type(clone) == type(metric), "Type of clone did not match metric type"
172+
168173
# move to device
169174
metric = metric.to(device)
170175
preds = apply_to_collection(preds, Tensor, lambda x: x.to(device))

0 commit comments

Comments
 (0)