eval.py
import argparse, os
import torch
from torch.autograd import Variable
import numpy as np
import time, math, glob
import scipy.io as sio
from torch.backends import cudnn
from memnet1 import MemNet
from utils import convert_state_dict

# Let cuDNN auto-tune and cache the fastest convolution algorithms for the observed input sizes.
cudnn.benchmark = True
parser = argparse.ArgumentParser(description="PyTorch MemNet Eval")
parser.add_argument("--cuda", action="store_true", help="use cuda?")
parser.add_argument("--model", default="checkpoint1/model_epoch_50.pth", type=str, help="model path")
parser.add_argument("--dataset", default="Set5", type=str, help="dataset name, Default: Set5")
parser.add_argument("--gpus", default="4", type=str, help="gpu ids (default: 0)")
def PSNR(pred, gt, shave_border=0):
    """Peak signal-to-noise ratio, 20 * log10(255 / RMSE), on 8-bit intensity values.

    Both images are cropped by `shave_border` pixels on every side before comparison
    (the usual convention in super-resolution evaluation). Identical images are
    reported as 100 dB rather than infinity.
    """
    height, width = pred.shape[:2]
    pred = pred[shave_border:height - shave_border, shave_border:width - shave_border]
    gt = gt[shave_border:height - shave_border, shave_border:width - shave_border]
    imdff = pred - gt
    rmse = math.sqrt(np.mean(imdff ** 2))
    if rmse == 0:
        return 100
    return 20 * math.log10(255.0 / rmse)
opt = parser.parse_args()
cuda = opt.cuda

if cuda:
    print("=> use gpu id: '{}'".format(opt.gpus))
    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
# model = torch.load(opt.model, map_location=lambda storage, loc: storage)["model"]
model = MemNet(1, 64, 6, 6)
state = convert_state_dict(torch.load(opt.model)['model'])
model.load_state_dict(state)
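# NOTE (assumption): the four MemNet arguments are commonly in_channels, base feature
# channels, number of memory blocks, and residual blocks per memory block; check
# memnet1.py for the exact signature. convert_state_dict (from utils) presumably remaps
# checkpoint keys (e.g. stripping a DataParallel 'module.' prefix) to match the bare model.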
if cuda:
    model = model.cuda()
else:
    model = model.cpu()
model.eval()
scales = [2,3,4]
#image_list = glob.glob(opt.dataset+"_mat/*.*")
image_list = glob.glob('data/SuperResolution/'+opt.dataset+"_mat/*.*")
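# Each .mat file is expected to provide 'im_gt_y' (ground-truth Y channel) and
# 'im_b_y' (bicubic-upscaled Y channel); the filename is assumed to encode the
# scale factor, which is how images are matched to a scale below.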
for scale in scales:
    avg_psnr_predicted = 0.0
    avg_psnr_bicubic = 0.0
    avg_elapsed_time = 0.0
    count = 0.0
    for image_name in image_list:
        if str(scale) in image_name:
            count += 1
            print("Processing ", image_name)
            im_gt_y = sio.loadmat(image_name)['im_gt_y']
            im_b_y = sio.loadmat(image_name)['im_b_y']
            im_gt_y = im_gt_y.astype(float)
            im_b_y = im_b_y.astype(float)

            psnr_bicubic = PSNR(im_gt_y, im_b_y, shave_border=scale)
            avg_psnr_bicubic += psnr_bicubic

            # Normalise to [0, 1] and add batch/channel dimensions: (1, 1, H, W).
            im_input = im_b_y / 255.
            # volatile=True is the legacy (pre-0.4) way of disabling autograd for inference;
            # on current PyTorch it is a no-op and torch.no_grad() would be used instead.
            im_input = Variable((torch.from_numpy(im_input).float()).view(1, -1, im_input.shape[0], im_input.shape[1]), volatile=True)

            if cuda:
                im_input = im_input.cuda()

            start_time = time.time()
            HR = model(im_input)
            elapsed_time = time.time() - start_time
            avg_elapsed_time += elapsed_time

            HR = HR.cpu()
            im_h_y = HR.data[0].numpy().astype(np.float32)

            # Rescale to [0, 255], clip, and drop the channel dimension.
            im_h_y = im_h_y * 255.
            im_h_y[im_h_y < 0] = 0
            im_h_y[im_h_y > 255.] = 255.
            im_h_y = im_h_y[0, :, :]

            psnr_predicted = PSNR(im_gt_y, im_h_y, shave_border=scale)
            avg_psnr_predicted += psnr_predicted

    print("Scale=", scale)
    print("Dataset=", opt.dataset)
    print("PSNR_predicted=", avg_psnr_predicted / count)
    print("PSNR_bicubic=", avg_psnr_bicubic / count)
    print("It takes average {}s for processing".format(avg_elapsed_time / count))