Skip to content

Commit 0436885

Browse files
SkafteNickiBorda
authored andcommitted
Reduce memory usage for certain image metrics (#2089)
(cherry picked from commit 51439e6)
1 parent 5e9f1fd commit 0436885

File tree

3 files changed

+62
-30
lines changed

3 files changed

+62
-30
lines changed

CHANGELOG.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313

1414
-
1515

16-
1716
### Changed
1817

19-
-
18+
- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))
2019

2120

2221
### Removed

src/torchmetrics/image/sam.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from typing import Any, List, Optional, Sequence, Union
1515

16-
from torch import Tensor
16+
from torch import Tensor, tensor
1717
from typing_extensions import Literal
1818

1919
from torchmetrics.functional.image.sam import _sam_compute, _sam_update
@@ -75,33 +75,50 @@ class SpectralAngleMapper(Metric):
7575

7676
preds: List[Tensor]
7777
target: List[Tensor]
78+
sum_sam: Tensor
79+
numel: Tensor
7880

7981
def __init__(
8082
self,
81-
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
83+
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
8284
**kwargs: Any,
8385
) -> None:
8486
super().__init__(**kwargs)
85-
rank_zero_warn(
86-
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer."
87-
" For large datasets, this may lead to a large memory footprint."
88-
)
89-
90-
self.add_state("preds", default=[], dist_reduce_fx="cat")
91-
self.add_state("target", default=[], dist_reduce_fx="cat")
87+
if reduction not in ("elementwise_mean", "sum", "none", None):
88+
raise ValueError(
89+
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
90+
)
91+
if reduction == "none" or reduction is None:
92+
rank_zero_warn(
93+
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer when using"
94+
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
95+
)
96+
self.add_state("preds", default=[], dist_reduce_fx="cat")
97+
self.add_state("target", default=[], dist_reduce_fx="cat")
98+
else:
99+
self.add_state("sum_sam", tensor(0.0), dist_reduce_fx="sum")
100+
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
92101
self.reduction = reduction
93102

94103
def update(self, preds: Tensor, target: Tensor) -> None:
95104
"""Update state with predictions and targets."""
96105
preds, target = _sam_update(preds, target)
97-
self.preds.append(preds)
98-
self.target.append(target)
106+
if self.reduction == "none" or self.reduction is None:
107+
self.preds.append(preds)
108+
self.target.append(target)
109+
else:
110+
sam_score = _sam_compute(preds, target, reduction="sum")
111+
self.sum_sam += sam_score
112+
p_shape = preds.shape
113+
self.numel += p_shape[0] * p_shape[2] * p_shape[3]
99114

100115
def compute(self) -> Tensor:
101116
"""Compute spectra over state."""
102-
preds = dim_zero_cat(self.preds)
103-
target = dim_zero_cat(self.target)
104-
return _sam_compute(preds, target, self.reduction)
117+
if self.reduction == "none" or self.reduction is None:
118+
preds = dim_zero_cat(self.preds)
119+
target = dim_zero_cat(self.target)
120+
return _sam_compute(preds, target, self.reduction)
121+
return self.sum_sam / self.numel if self.reduction == "elementwise_mean" else self.sum_sam
105122

106123
def plot(
107124
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None

src/torchmetrics/image/uqi.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414
from typing import Any, List, Optional, Sequence, Union
1515

16-
from torch import Tensor
16+
from torch import Tensor, tensor
1717
from typing_extensions import Literal
1818

1919
from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update
@@ -73,6 +73,8 @@ class UniversalImageQualityIndex(Metric):
7373

7474
preds: List[Tensor]
7575
target: List[Tensor]
76+
sum_uqi: Tensor
77+
numel: Tensor
7678

7779
def __init__(
7880
self,
@@ -82,29 +84,43 @@ def __init__(
8284
**kwargs: Any,
8385
) -> None:
8486
super().__init__(**kwargs)
85-
rank_zero_warn(
86-
"Metric `UniversalImageQualityIndex` will save all targets and"
87-
" predictions in buffer. For large datasets this may lead"
88-
" to large memory footprint."
89-
)
90-
91-
self.add_state("preds", default=[], dist_reduce_fx="cat")
92-
self.add_state("target", default=[], dist_reduce_fx="cat")
87+
if reduction not in ("elementwise_mean", "sum", "none", None):
88+
raise ValueError(
89+
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
90+
)
91+
if reduction is None or reduction == "none":
92+
rank_zero_warn(
93+
"Metric `UniversalImageQualityIndex` will save all targets and predictions in the buffer when using"
94+
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
95+
)
96+
self.add_state("preds", default=[], dist_reduce_fx="cat")
97+
self.add_state("target", default=[], dist_reduce_fx="cat")
98+
else:
99+
self.add_state("sum_uqi", tensor(0.0), dist_reduce_fx="sum")
100+
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
93101
self.kernel_size = kernel_size
94102
self.sigma = sigma
95103
self.reduction = reduction
96104

97105
def update(self, preds: Tensor, target: Tensor) -> None:
98106
"""Update state with predictions and targets."""
99107
preds, target = _uqi_update(preds, target)
100-
self.preds.append(preds)
101-
self.target.append(target)
108+
if self.reduction is None or self.reduction == "none":
109+
self.preds.append(preds)
110+
self.target.append(target)
111+
else:
112+
uqi_score = _uqi_compute(preds, target, self.kernel_size, self.sigma, reduction="sum")
113+
self.sum_uqi += uqi_score
114+
ps = preds.shape
115+
self.numel += ps[0] * ps[1] * (ps[2] - self.kernel_size[0] + 1) * (ps[3] - self.kernel_size[1] + 1)
102116

103117
def compute(self) -> Tensor:
104118
"""Compute explained variance over state."""
105-
preds = dim_zero_cat(self.preds)
106-
target = dim_zero_cat(self.target)
107-
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
119+
if self.reduction == "none" or self.reduction is None:
120+
preds = dim_zero_cat(self.preds)
121+
target = dim_zero_cat(self.target)
122+
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
123+
return self.sum_uqi / self.numel if self.reduction == "elementwise_mean" else self.sum_uqi
108124

109125
def plot(
110126
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None

0 commit comments

Comments
 (0)