13
13
# limitations under the License.
14
14
from typing import Any , List , Optional , Sequence , Union
15
15
16
- from torch import Tensor
16
+ from torch import Tensor , tensor
17
17
from typing_extensions import Literal
18
18
19
19
from torchmetrics .functional .image .uqi import _uqi_compute , _uqi_update
@@ -73,6 +73,8 @@ class UniversalImageQualityIndex(Metric):
73
73
74
74
preds : List [Tensor ]
75
75
target : List [Tensor ]
76
+ sum_uqi : Tensor
77
+ numel : Tensor
76
78
77
79
def __init__ (
78
80
self ,
@@ -82,29 +84,43 @@ def __init__(
82
84
** kwargs : Any ,
83
85
) -> None :
84
86
super ().__init__ (** kwargs )
85
- rank_zero_warn (
86
- "Metric `UniversalImageQualityIndex` will save all targets and"
87
- " predictions in buffer. For large datasets this may lead"
88
- " to large memory footprint."
89
- )
90
-
91
- self .add_state ("preds" , default = [], dist_reduce_fx = "cat" )
92
- self .add_state ("target" , default = [], dist_reduce_fx = "cat" )
87
+ if reduction not in ("elementwise_mean" , "sum" , "none" , None ):
88
+ raise ValueError (
89
+ f"The `reduction` { reduction } is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
90
+ )
91
+ if reduction is None or reduction == "none" :
92
+ rank_zero_warn (
93
+ "Metric `UniversalImageQualityIndex` will save all targets and predictions in the buffer when using"
94
+ "`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
95
+ )
96
+ self .add_state ("preds" , default = [], dist_reduce_fx = "cat" )
97
+ self .add_state ("target" , default = [], dist_reduce_fx = "cat" )
98
+ else :
99
+ self .add_state ("sum_uqi" , tensor (0.0 ), dist_reduce_fx = "sum" )
100
+ self .add_state ("numel" , tensor (0 ), dist_reduce_fx = "sum" )
93
101
self .kernel_size = kernel_size
94
102
self .sigma = sigma
95
103
self .reduction = reduction
96
104
97
105
def update (self , preds : Tensor , target : Tensor ) -> None :
98
106
"""Update state with predictions and targets."""
99
107
preds , target = _uqi_update (preds , target )
100
- self .preds .append (preds )
101
- self .target .append (target )
108
+ if self .reduction is None or self .reduction == "none" :
109
+ self .preds .append (preds )
110
+ self .target .append (target )
111
+ else :
112
+ uqi_score = _uqi_compute (preds , target , self .kernel_size , self .sigma , reduction = "sum" )
113
+ self .sum_uqi += uqi_score
114
+ ps = preds .shape
115
+ self .numel += ps [0 ] * ps [1 ] * (ps [2 ] - self .kernel_size [0 ] + 1 ) * (ps [3 ] - self .kernel_size [1 ] + 1 )
102
116
103
117
def compute (self ) -> Tensor :
104
118
"""Compute explained variance over state."""
105
- preds = dim_zero_cat (self .preds )
106
- target = dim_zero_cat (self .target )
107
- return _uqi_compute (preds , target , self .kernel_size , self .sigma , self .reduction )
119
+ if self .reduction == "none" or self .reduction is None :
120
+ preds = dim_zero_cat (self .preds )
121
+ target = dim_zero_cat (self .target )
122
+ return _uqi_compute (preds , target , self .kernel_size , self .sigma , self .reduction )
123
+ return self .sum_uqi / self .numel if self .reduction == "elementwise_mean" else self .sum_uqi
108
124
109
125
def plot (
110
126
self , val : Optional [Union [Tensor , Sequence [Tensor ]]] = None , ax : Optional [_AX_TYPE ] = None
0 commit comments