-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path statics.py
76 lines (59 loc) · 2.41 KB
/
statics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
# Names exported by ``from statics import *``.
__all__ = ["AverageMeter", "evaluator"]
class AverageMeter(object):
    r"""Computes and stores the average and current value.

    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262

    Attributes (all running statistics start at 0):
        name: label shown by ``__repr__``.
        val: the most recent value passed to :meth:`update`.
        sum: weighted running sum of ``val * n``.
        count: total weight, i.e. the running sum of ``n``.
        avg: ``sum / count``; left unchanged while ``count`` is 0.
    """

    def __init__(self, name):
        self.name = name
        # reset() initialises val/avg/sum/count; no need to repeat it here.
        self.reset()

    def reset(self):
        """Clear all running statistics (``name`` is kept)."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: the new value (e.g. a per-batch mean).
            n: its weight (e.g. the batch size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on an empty meter would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __repr__(self):
        return f"==> For {self.name}: sum={self.sum}; avg={self.avg}"
def _fft_real_pairs(x):
    """Unnormalised 1-D DFT along dim -2 of ``x``, whose last dim holds (real, imag) pairs.

    Returns a tensor in the same (..., L, 2) real/imag layout. Uses the
    ``torch.fft`` module when available (PyTorch >= 1.8, or 1.7 with
    ``import torch.fft``). Otherwise it uses the legacy ``torch.fft(input,
    signal_ndim)`` function, which was removed when ``torch.fft`` became a module.
    """
    fft_api = getattr(torch, "fft", None)
    if hasattr(fft_api, "fft"):
        # view_as_complex needs the (real, imag) dim to be innermost and contiguous.
        spectrum = torch.fft.fft(torch.view_as_complex(x.contiguous()), dim=-1)
        return torch.view_as_real(spectrum)
    return torch.fft(x, signal_ndim=1)


def evaluator(sparse_pred, sparse_gt, raw_gt, *, nc_expand=257, n_freq=125):
    r"""Evaluation of decoding implemented in PyTorch Tensor.

    Computes normalized mean square error (NMSE, in dB) and rho.

    Args:
        sparse_pred: prediction indexed as ``(batch, 2, nt, nc)``. Channels 0/1
            are treated as the real/imaginary parts. A 0.5 offset is subtracted
            before use.
        sparse_gt: ground truth in the same layout as ``sparse_pred``.
        raw_gt: reference spectrum, presumably ``(batch, nt, n_freq, 2)`` with
            the last dim holding (real, imag). It must broadcast elementwise
            against the truncated FFT of the prediction.
        nc_expand: length to which the ``nc`` axis of the prediction is
            zero-padded before the FFT (default 257).
        n_freq: number of leading FFT bins compared against ``raw_gt``
            (default 125).

    Returns:
        ``(rho, nmse)`` as 0-d tensors.
        - ``rho``: the mean over batch and frequency bins of
          ``|sum_t conj(p) * g| / (||p|| * ||g||)``, where the sums and norms
          run over the ``nt`` axis.
        - ``nmse``: ``10 * log10`` of the batch mean of
          ``||gt - pred||^2 / ||gt||^2``.
    """
    with torch.no_grad():
        # De-centralize: remove the 0.5 offset (presumably from a [0, 1]
        # normalisation; confirm against the data pipeline).
        sparse_gt = sparse_gt - 0.5
        sparse_pred = sparse_pred - 0.5

        # NMSE: per-sample ratio, averaged over the batch, then converted to dB.
        power_gt = sparse_gt[:, 0, :, :] ** 2 + sparse_gt[:, 1, :, :] ** 2
        difference = sparse_gt - sparse_pred
        mse = difference[:, 0, :, :] ** 2 + difference[:, 1, :, :] ** 2
        nmse = 10 * torch.log10((mse.sum(dim=[1, 2]) / power_gt.sum(dim=[1, 2])).mean())

        # Rho: zero-pad the nc axis to nc_expand and FFT along it, keeping the
        # first n_freq bins.
        sparse_pred = sparse_pred.permute(0, 2, 3, 1)  # Move the real/imaginary dim to the last
        n, nt, nc, _ = sparse_pred.shape
        zeros = sparse_pred.new_zeros((n, nt, nc_expand - nc, 2))
        sparse_pred = torch.cat((sparse_pred, zeros), dim=2)
        raw_pred = _fft_real_pairs(sparse_pred)[:, :, :n_freq, :]

        # Per-bin magnitudes over the nt axis.
        norm_pred = raw_pred[..., 0] ** 2 + raw_pred[..., 1] ** 2
        norm_pred = torch.sqrt(norm_pred.sum(dim=1))
        norm_gt = raw_gt[..., 0] ** 2 + raw_gt[..., 1] ** 2
        norm_gt = torch.sqrt(norm_gt.sum(dim=1))

        # Real and imaginary parts of sum_t conj(pred) * gt.
        real_cross = raw_pred[..., 0] * raw_gt[..., 0] + raw_pred[..., 1] * raw_gt[..., 1]
        real_cross = real_cross.sum(dim=1)
        imag_cross = raw_pred[..., 0] * raw_gt[..., 1] - raw_pred[..., 1] * raw_gt[..., 0]
        imag_cross = imag_cross.sum(dim=1)
        norm_cross = torch.sqrt(real_cross ** 2 + imag_cross ** 2)

        rho = (norm_cross / (norm_pred * norm_gt)).mean()
        return rho, nmse