-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
116 lines (87 loc) · 3.55 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import os
import PIL.Image
import numpy as np
from keras.applications import VGG19
from Model import build_model, Model
from ModelConfig import img_input_shape
from Huffman import huffman_coding
def _list_images(input_path):
    """Return the image paths to process for ``input_path``.

    A file path yields a one-element list. A directory yields the regular
    files directly inside it, sorted by name so runs are reproducible
    (``os.listdir`` order is arbitrary); subdirectories are skipped.

    Raises:
        FileNotFoundError: if ``input_path`` is empty/None or is neither a
            file nor a directory.
    """
    if input_path and os.path.isfile(input_path):
        return [input_path]
    if input_path and os.path.isdir(input_path):
        candidates = (os.path.join(input_path, name)
                      for name in sorted(os.listdir(input_path)))
        return [path for path in candidates if os.path.isfile(path)]
    # FileNotFoundError subclasses Exception, so callers that caught the
    # previous generic Exception keep working.
    raise FileNotFoundError("input path does not exist")


def _load_image(path):
    """Open ``path`` and resize it to the model's spatial input size.

    Returns the resized PIL image; the underlying file handle is closed
    before returning.
    """
    with PIL.Image.open(path) as src:
        # NOTE(review): assumes img_input_shape is (height, width, channels)
        # (Keras channels_last), consistent with the reshape done by the
        # caller on np.asarray(image) — confirm in ModelConfig.
        if len(img_input_shape) == 3 and img_input_shape[2] == 3 and src.mode != "RGB":
            # RGBA / palette / grayscale files would otherwise have the wrong
            # number of values for the reshape to the model input shape.
            src_rgb = src.convert("RGB")
        else:
            src_rgb = src
        # PIL's resize takes (width, height), the reverse of the array shape.
        # LANCZOS is the filter the removed ANTIALIAS alias pointed to.
        return src_rgb.resize((img_input_shape[1], img_input_shape[0]),
                              PIL.Image.LANCZOS)


def predict_from_ae(input_path, autoencoder, limit=10):
    """Run up to ``limit`` images through ``autoencoder`` and report metrics.

    For each image this Huffman-codes the first model output, rescales the
    second output to uint8 as the reconstruction, prints per-image MSE/PSNR
    and compression rate, and writes ``output/<stem>_true.png`` (resized
    input) and ``output/<stem>_pred.png`` (reconstruction). A final summary
    line reports ``bpp`` and mean PSNR/MSE.

    Args:
        input_path: path to a single image file or a directory of images.
        autoencoder: model whose ``predict`` result is indexed as
            ``[0]`` (codes fed to ``huffman_coding``) and ``[1]``
            (reconstruction, presumably in [0, 1] given the ``* 255``
            rescale — confirm against build_model).
        limit: maximum number of images to process.

    Raises:
        FileNotFoundError: if ``input_path`` does not exist.
    """
    img_list = _list_images(input_path)
    n_images = min(limit, len(img_list))
    if n_images <= 0:
        # Avoid nan / divide-by-zero in the summary statistics below.
        print("no images to process")
        return
    os.makedirs("output", exist_ok=True)

    mse_list = []
    psnr_list = []
    size_list = []
    dic_size = []
    tx_list = []
    for img_path in img_list[:n_images]:
        img_img = _load_image(img_path)
        img = (np.asarray(img_img) / 255).reshape(1, *img_input_shape)
        reconstruction = autoencoder.predict(img)

        codes = reconstruction[0]
        mapping, original_size, compressed_size = huffman_coding(codes)
        tx = 1 - compressed_size / original_size
        size_list.append(compressed_size)
        tx_list.append(tx)
        print('img {} compression rate : {}'.format(img_path, tx))
        # NOTE(review): assumes mapping yields (symbol, code) pairs and that 32
        # is the per-entry cost of storing the symbol — confirm against Huffman.
        dic_size.extend(32 + len(code[1]) for code in mapping)

        # np.uint8 truncates after clipping, matching the saved PNG values.
        decoded = np.uint8(np.clip(reconstruction[1] * 255, 0, 255))
        decoded = decoded.reshape(*img_input_shape)
        mse = np.mean((img * 255 - decoded) ** 2)
        # A perfect reconstruction has mse == 0: report infinite PSNR
        # explicitly instead of dividing by zero.
        psnr = 10 * np.log10(255 ** 2 / mse) if mse > 0 else np.inf
        mse_list.append(mse)
        psnr_list.append(psnr)
        print('img {} mse : {} psnr : {}'.format(img_path, mse, psnr))

        # splitext keeps multi-dot names ("a.b.png" -> "a.b") distinct.
        stem = os.path.splitext(os.path.basename(img_path))[0]
        img_img.save(os.path.join("output", stem + "_true.png"))
        PIL.Image.fromarray(decoded).save(os.path.join("output", stem + "_pred.png"))

    # NOTE(review): the denominator counts every array element (H*W*C), so
    # this is bits per channel value if compressed_size is in bits — confirm
    # the intended unit. np.prod replaces np.product (removed in NumPy 2.0).
    bpp = (np.sum(size_list) + np.sum(dic_size)) / (n_images * np.prod(img_input_shape))
    psnr = np.mean(psnr_list)
    mse = np.mean(mse_list)
    print("bpp: {}, psnr: {}, mse: {}".format(bpp, psnr, mse))
def predict_from_weights(input_path, weight_path, limit=10):
    """Build the autoencoder, load ``weight_path`` into it, and run predictions.

    Args:
        input_path: image file or directory, forwarded to ``predict_from_ae``.
        weight_path: path to a weights file passed to
            ``autoencoder.load_weights``.
        limit: maximum number of images to process.

    Raises:
        FileNotFoundError: if ``weight_path`` is empty/None or not a file.
    """
    # Validate first: constructing VGG19 with weights="imagenet" may download
    # and load a large checkpoint, which is wasted work if the file is missing.
    # FileNotFoundError subclasses Exception, so existing handlers still match.
    if not weight_path or not os.path.isfile(weight_path):
        raise FileNotFoundError("weight path does not exist")

    # VGG for the perceptual loss: a feature extractor exposing the
    # block2_pool and block5_pool activations. build_model requires it even
    # though only the returned autoencoder is used here.
    base_model = VGG19(weights="imagenet", include_top=False,
                       input_shape=img_input_shape)
    perceptual_model = Model(
        inputs=base_model.input,
        outputs=[base_model.get_layer("block2_pool").output,
                 base_model.get_layer("block5_pool").output],
        name="VGG",
    )
    autoencoder, _ = build_model(perceptual_model)

    print("loading weights from {}".format(weight_path))
    autoencoder.load_weights(weight_path)
    predict_from_ae(input_path, autoencoder, limit)
if __name__ == "__main__":
    # Command-line entry point: parse arguments and run prediction.
    argparser = argparse.ArgumentParser(
        description='predict images')
    # Both paths are mandatory: without them the run fails deep inside
    # os.path checks, so let argparse report the problem up front.
    argparser.add_argument(
        '-w',
        '--weight',
        required=True,
        help='path to weight file')
    argparser.add_argument(
        '-i',
        '--input',
        required=True,
        help='path to input file or folder')
    # type=int lets argparse reject non-integer values with a usage message.
    argparser.add_argument(
        '-l',
        '--limit',
        type=int,
        default=10,
        help='maximum number of prediction')
    args = argparser.parse_args()
    limit = args.limit
    input_path = args.input
    weight_path = args.weight
    predict_from_weights(input_path, weight_path, limit)