Skip to content

Commit 249266c

Browse files
authored
Merge branch 'master' into forward_and_compute_to_device
2 parents d09807c + d5aa720 commit 249266c

File tree

4 files changed

+128
-1
lines changed

4 files changed

+128
-1
lines changed

torchmetrics/functional/image/psnr.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,31 @@ def _psnr_compute(
2626
base: float = 10.0,
2727
reduction: str = 'elementwise_mean',
2828
) -> Tensor:
29+
"""
30+
Computes peak signal-to-noise ratio.
31+
32+
Args:
33+
sum_squared_error: Sum of square of errors over all observations
34+
n_obs: Number of predictions or observations
35+
data_range:
36+
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
37+
when ``dim`` is not None.
38+
base: a base of a logarithm to use (default: 10)
39+
reduction: a method to reduce metric score over labels.
40+
41+
- ``'elementwise_mean'``: takes the mean (default)
42+
- ``'sum'``: takes the sum
43+
- ``'none'``: no reduction will be applied
44+
45+
Example:
46+
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
47+
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
48+
>>> data_range = target.max() - target.min()
49+
>>> sum_squared_error, n_obs = _psnr_update(preds, target)
50+
>>> _psnr_compute(sum_squared_error, n_obs, data_range)
51+
tensor(2.5527)
52+
"""
53+
2954
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
3055
psnr_vals = psnr_base_e * (10 / torch.log(tensor(base)))
3156
return reduce(psnr_vals, reduction=reduction)
@@ -36,6 +61,17 @@ def _psnr_update(
3661
target: Tensor,
3762
dim: Optional[Union[int, Tuple[int, ...]]] = None,
3863
) -> Tuple[Tensor, Tensor]:
64+
"""
65+
Updates and returns variables required to compute peak signal-to-noise ratio.
66+
67+
Args:
68+
preds: Predicted tensor
69+
target: Ground truth tensor
70+
dim:
71+
Dimensions to reduce PSNR scores over, provided as either an integer or a tuple of integers. Default is
72+
None meaning scores will be reduced across all dimensions.
73+
"""
74+
3975
if dim is None:
4076
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
4177
n_obs = tensor(target.numel(), device=target.device)
@@ -66,7 +102,7 @@ def psnr(
66102
dim: Optional[Union[int, Tuple[int, ...]]] = None,
67103
) -> Tensor:
68104
"""
69-
Computes the peak signal-to-noise ratio
105+
Computes the peak signal-to-noise ratio.
70106
71107
Args:
72108
preds: estimated signal

torchmetrics/functional/image/ssim.py

+66
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,19 @@
2222

2323

2424
def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
25+
"""
26+
Computes 1D gaussian kernel
27+
28+
Args:
29+
kernel_size: size of the gaussian kernel
30+
sigma: Standard deviation of the gaussian kernel
31+
dtype: data type of the output tensor
32+
device: device of the output tensor
33+
34+
Example:
35+
>>> _gaussian(3, 1, torch.float, 'cpu')
36+
tensor([[0.2741, 0.4519, 0.2741]])
37+
"""
2538
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
2639
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
2740
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
@@ -30,6 +43,25 @@ def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.
3043
def _gaussian_kernel(
3144
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
3245
) -> Tensor:
46+
"""
47+
Computes 2D gaussian kernel
48+
49+
Args:
50+
channel: number of channels in the image
51+
kernel_size: size of the gaussian kernel as a tuple (h, w)
52+
sigma: Standard deviation of the gaussian kernel
53+
dtype: data type of the output tensor
54+
device: device of the output tensor
55+
56+
Example:
57+
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
58+
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
59+
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
60+
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
61+
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
62+
[0.0030, 0.0133, 0.0219, 0.0133, 0.0030]]]])
63+
"""
64+
3365
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
3466
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
3567
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
@@ -38,6 +70,15 @@ def _gaussian_kernel(
3870

3971

4072
def _ssim_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
73+
"""
74+
Updates and returns variables required to compute Structural Similarity Index Measure.
75+
Checks for same shape and type of the input tensors.
76+
77+
Args:
78+
preds: Predicted tensor
79+
target: Ground truth tensor
80+
"""
81+
4182
if preds.dtype != target.dtype:
4283
raise TypeError(
4384
"Expected `preds` and `target` to have the same data type."
@@ -62,6 +103,31 @@ def _ssim_compute(
62103
k1: float = 0.01,
63104
k2: float = 0.03,
64105
) -> Tensor:
106+
"""
107+
Computes Structural Similarity Index Measure
108+
109+
Args:
110+
preds: estimated image
111+
target: ground truth image
112+
kernel_size: size of the gaussian kernel (default: (11, 11))
113+
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
114+
reduction: a method to reduce metric score over labels.
115+
116+
- ``'elementwise_mean'``: takes the mean (default)
117+
- ``'sum'``: takes the sum
118+
- ``'none'``: no reduction will be applied
119+
120+
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
121+
k1: Parameter of SSIM. Default: 0.01
122+
k2: Parameter of SSIM. Default: 0.03
123+
124+
Example:
125+
>>> preds = torch.rand([16, 1, 16, 16])
126+
>>> target = preds * 0.75
127+
>>> preds, target = _ssim_update(preds, target)
128+
>>> _ssim_compute(preds, target)
129+
tensor(0.9219)
130+
"""
65131
if len(kernel_size) != 2 or len(sigma) != 2:
66132
raise ValueError(
67133
"Expected `kernel_size` and `sigma` to have the length of two."

torchmetrics/functional/retrieval/ndcg.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121

2222
def _dcg(target: Tensor) -> Tensor:
23+
""" Computes Discounted Cumulative Gain for input tensor """
2324
denom = torch.log2(torch.arange(target.shape[-1], device=target.device) + 2.0)
2425
return (target / denom).sum(dim=-1)
2526

torchmetrics/functional/text/bleu.py

+24
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,19 @@ def _bleu_score_update(
5454
ref_len: Tensor,
5555
n_gram: int = 4
5656
) -> Tuple[Tensor, Tensor]:
57+
"""
58+
Updates and returns variables required to compute the BLEU score.
59+
60+
Args:
61+
reference_corpus: An iterable of iterables of reference corpus
62+
translate_corpus: An iterable of machine translated corpus
63+
numerator: Numerator of precision score (true positives)
64+
denominator: Denominator of precision score (true positives + false positives)
65+
trans_len: count of words in a candidate translation
66+
ref_len: count of words in a reference translation
67+
n_gram: n-gram size, in the range 1 to 4
68+
"""
69+
5770
for (translation, references) in zip(translate_corpus, reference_corpus):
5871
trans_len += len(translation)
5972
ref_len_list = [len(ref) for ref in references]
@@ -84,6 +97,17 @@ def _bleu_score_compute(
8497
n_gram: int = 4,
8598
smooth: bool = False
8699
) -> Tensor:
100+
"""
101+
Computes the BLEU score.
102+
103+
Args:
104+
trans_len: count of words in a candidate translation
105+
ref_len: count of words in a reference translation
106+
numerator: Numerator of precision score (true positives)
107+
denominator: Denominator of precision score (true positives + false positives)
108+
n_gram: n-gram size, in the range 1 to 4
109+
smooth: Whether or not to apply smoothing
110+
"""
87111
if min(numerator) == 0.0:
88112
return tensor(0.0)
89113

0 commit comments

Comments
 (0)