22
22
23
23
24
24
def _gaussian (kernel_size : int , sigma : float , dtype : torch .dtype , device : torch .device ) -> Tensor :
25
+ """
26
+ Computes 1D gaussian kernel
27
+
28
+ Args:
29
+ kernel_size: size of the gaussian kernel
30
+ sigma: Standard deviation of the gaussian kernel
31
+ dtype: data type of the output tensor
32
+ device: device of the output tensor
33
+
34
+ Example:
35
+ >>> _gaussian(3, 1, torch.float, 'cpu')
36
+ tensor([[0.2741, 0.4519, 0.2741]])
37
+ """
25
38
dist = torch .arange (start = (1 - kernel_size ) / 2 , end = (1 + kernel_size ) / 2 , step = 1 , dtype = dtype , device = device )
26
39
gauss = torch .exp (- torch .pow (dist / sigma , 2 ) / 2 )
27
40
return (gauss / gauss .sum ()).unsqueeze (dim = 0 ) # (1, kernel_size)
@@ -30,6 +43,25 @@ def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.
30
43
def _gaussian_kernel (
31
44
channel : int , kernel_size : Sequence [int ], sigma : Sequence [float ], dtype : torch .dtype , device : torch .device
32
45
) -> Tensor :
46
+ """
47
+ Computes 2D gaussian kernel
48
+
49
+ Args:
50
+ channel: number of channels in the image
51
+ kernel_size: size of the gaussian kernel as a tuple (h, w)
52
+ sigma: Standard deviation of the gaussian kernel
53
+ dtype: data type of the output tensor
54
+ device: device of the output tensor
55
+
56
+ Example:
57
+ >>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
58
+ tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
59
+ [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
60
+ [0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
61
+ [0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
62
+ [0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
63
+ """
64
+
33
65
gaussian_kernel_x = _gaussian (kernel_size [0 ], sigma [0 ], dtype , device )
34
66
gaussian_kernel_y = _gaussian (kernel_size [1 ], sigma [1 ], dtype , device )
35
67
kernel = torch .matmul (gaussian_kernel_x .t (), gaussian_kernel_y ) # (kernel_size, 1) * (1, kernel_size)
@@ -38,6 +70,15 @@ def _gaussian_kernel(
38
70
39
71
40
72
def _ssim_update (preds : Tensor , target : Tensor ) -> Tuple [Tensor , Tensor ]:
73
+ """
74
+ Updates and returns variables required to compute Structural Similarity Index Measure.
75
+ Checks for same shape and type of the input tensors.
76
+
77
+ Args:
78
+ preds: Predicted tensor
79
+ target: Ground truth tensor
80
+ """
81
+
41
82
if preds .dtype != target .dtype :
42
83
raise TypeError (
43
84
"Expected `preds` and `target` to have the same data type."
@@ -62,6 +103,31 @@ def _ssim_compute(
62
103
k1 : float = 0.01 ,
63
104
k2 : float = 0.03 ,
64
105
) -> Tensor :
106
+ """
107
+ Computes Structual Similarity Index Measure
108
+
109
+ Args:
110
+ preds: estimated image
111
+ target: ground truth image
112
+ kernel_size: size of the gaussian kernel (default: (11, 11))
113
+ sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
114
+ reduction: a method to reduce metric score over labels.
115
+
116
+ - ``'elementwise_mean'``: takes the mean (default)
117
+ - ``'sum'``: takes the sum
118
+ - ``'none'``: no reduction will be applied
119
+
120
+ data_range: Range of the image. If ``None``, it is determined from the image (max - min)
121
+ k1: Parameter of SSIM. Default: 0.01
122
+ k2: Parameter of SSIM. Default: 0.03
123
+
124
+ Example:
125
+ >>> preds = torch.rand([16, 1, 16, 16])
126
+ >>> target = preds * 0.75
127
+ >>> preds, target = _ssim_update(preds, target)
128
+ >>> _ssim_compute(preds, target)
129
+ tensor(0.9219)
130
+ """
65
131
if len (kernel_size ) != 2 or len (sigma ) != 2 :
66
132
raise ValueError (
67
133
"Expected `kernel_size` and `sigma` to have the length of two."
0 commit comments