-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
145 lines (123 loc) · 5.23 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import pickle
import shutil
import PIL.Image
import numpy as np
from keras.applications import VGG19
from keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping, LearningRateScheduler
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import plot_model as keras_utils_plot_model
from CustomCallbacks import TensorBoardImage, HuffmanCallback, schedule
from CustomLoss import loss, perceptual_2, perceptual_5, entropy
from Generator import DataGenerator
from Model import build_model
from ModelConfig import img_input_shape, dataset_path, train_dir, validation_dir, test_dir, batch_size, epoch_nb
from predict import predict_from_ae
from Utils import generate_experiment
def train(autoencoder,
          nb_epochs,
          exp_path,
          train_generator,
          val_generator,
          test_list,
          batch_size,
          extra_callbacks=None):
    """Train an already-compiled autoencoder with TensorBoard tracking.

    Sets up the standard callback suite (TensorBoard scalars, image
    reconstruction snapshots, best-weights checkpointing, Huffman-entropy
    monitoring), runs ``fit_generator``, pickles the training history and
    archives the TensorBoard run directory into the experiment folder.

    Parameters
    ----------
    autoencoder : keras.Model
        Model to train. It MUST already be compiled.
    nb_epochs : int
        Number of epochs to train for.
    exp_path : str
        Experiment directory receiving ``weights.hdf5``, ``history`` and
        a copy of the TensorBoard logs.
    train_generator : keras generator/Sequence
        Yields training batches.
    val_generator : keras generator/Sequence
        Its first batch (``val_generator[0]``) is used as the validation
        data and as input to the Huffman callback.
    test_list : list
        Image list forwarded to the ``TensorBoardImage`` callback.
    batch_size : int
        Batch size reported to TensorBoard.
    extra_callbacks : list of keras.callbacks.Callback, optional
        Additional callbacks appended to the default set. Defaults to
        none (``None`` sentinel avoids the shared-mutable-default bug of
        the previous ``extra_callbacks=[]`` signature).

    Returns
    -------
    keras.Model
        The trained autoencoder.
    """
    if extra_callbacks is None:
        extra_callbacks = []

    # Determine the next run index. The previous scheme compared only the
    # LAST character of each "runN" directory name, which breaks as soon
    # as a double-digit run exists ("run10" -> "0"). Parse the full
    # numeric suffix instead, and tolerate a missing ./logs directory.
    os.makedirs("./logs", exist_ok=True)
    indexes = [int(run[3:]) for run in os.listdir("./logs")
               if run.startswith("run") and run[3:].isdigit()]
    log_index = max(indexes) + 1 if indexes else 0

    # Tracking callbacks.
    tensorboard = TensorBoard(
        log_dir='./logs/run' + str(log_index),
        histogram_freq=0,
        batch_size=batch_size)
    tensorboard_image = TensorBoardImage(
        "Reconstruction",
        test_list=test_list,
        logs_path='./logs/run' + str(log_index),
        save_img=True,
        exp_path=exp_path)
    # Keep only the best weights seen so far (lowest validation loss).
    checkpoint = ModelCheckpoint(exp_path + "/weights.hdf5", save_best_only=True)
    huffman = HuffmanCallback(val_generator[0][0])

    history = autoencoder.fit_generator(train_generator,
                                        epochs=nb_epochs,
                                        validation_data=val_generator[0],
                                        callbacks=[tensorboard_image,
                                                   tensorboard,
                                                   checkpoint,
                                                   huffman] + extra_callbacks)

    # Dump the history dict for later analysis/plotting.
    with open(exp_path + '/history', 'wb') as file_pi:
        pickle.dump(history.history, file_pi)

    # Archive this run's TensorBoard logs next to the experiment artifacts.
    shutil.copytree('./logs/run' + str(log_index), exp_path + '/run' + str(log_index))
    return autoencoder
if __name__ == '__main__':
    # Collect the image file lists for each dataset split.
    train_list = os.listdir(dataset_path + "/" + train_dir)
    val_list = os.listdir(dataset_path + "/" + validation_dir)
    test_list = os.listdir(dataset_path + "/" + test_dir)

    # Create the experiment output directory.
    exp_path = generate_experiment()

    # Instantiate the VGG19 backbone used for the texture/perceptual losses.
    base_model = VGG19(weights="imagenet", include_top=False,
                       input_shape=img_input_shape)

    # Expose the two pooling layers whose activations feed the perceptual
    # losses (perceptual_2 on block2, perceptual_5 on block5).
    perceptual_model = Model(inputs=base_model.input,
                             outputs=[base_model.get_layer("block2_pool").output,
                                      base_model.get_layer("block5_pool").output],
                             name="VGG")

    # Freeze the VGG: it only supplies fixed features for the loss and is
    # never trained.
    perceptual_model.trainable = False
    for layer in perceptual_model.layers:
        layer.trainable = False

    # Trick to force perceptual_model instantiation: run one dummy
    # prediction on a resized validation image.
    # NOTE: PIL.Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the
    # same filter and has been available since Pillow 2.7.
    img = PIL.Image.open(dataset_path + "/" + validation_dir + "/" + val_list[0])
    img_img = img.resize(img_input_shape[0:2], PIL.Image.LANCZOS)
    img = np.asarray(img_img) / 255
    img = img.reshape(1, *img_input_shape)
    perceptual_model.predict(img)

    # Build the autoencoder (see Model.py).
    autoencoder, _ = build_model(perceptual_model)

    # Generators for train/validation data. The validation generator uses
    # the whole validation list as a single batch.
    train_generator = DataGenerator(
        dataset_path + "/" + train_dir, train_list, perceptual_model, batch_size, img_input_shape)
    val_generator = DataGenerator(
        dataset_path + "/" + validation_dir, val_list, perceptual_model, len(val_list), img_input_shape)

    # Optionally render the model graph to a PNG.
    plot_model = False
    if plot_model:
        keras_utils_plot_model(autoencoder, to_file='autoencoder.png')

    # Optionally resume from previously saved weights.
    load_model = False
    if load_model:
        weight_path = "weights.hdf5"
        print("loading weights from {}".format(weight_path))
        autoencoder.load_weights(weight_path)

    # Compile with Adam; each named output layer gets its own loss term.
    optimizer = Adam(lr=1e-4, clipnorm=1)
    autoencoder.compile(optimizer=optimizer, loss={"clipping_layer_1": loss,
                                                   "rounding_layer_1": entropy,
                                                   "VGG_block_2": perceptual_2,
                                                   "VGG_block_5": perceptual_5})

    # Extra callbacks: LR schedule and early stopping on validation loss.
    lr_decay = LearningRateScheduler(schedule)
    early_stopping = EarlyStopping(
        monitor='val_loss',
        min_delta=1e-4,
        patience=20,
        verbose=1,
        mode='auto')

    autoencoder = train(autoencoder,
                        epoch_nb,
                        exp_path,
                        train_generator,
                        val_generator,
                        test_list,
                        batch_size,
                        [early_stopping, lr_decay])

    # Run reconstruction on the validation set with the trained model.
    predict_from_ae(dataset_path + "/" + validation_dir, autoencoder)