-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_srcnn.py
125 lines (106 loc) · 5.81 KB
/
train_srcnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pathlib
# This script is used for training the SRCNN for improving the quality
# of the images generated by the GAN
# The input images are the upscaled (768x432) version of the generated images
# and the output images are also 768x432
# Defining preprocess_images, used to rescale pixel values so that 0 maps to -1 and 255 maps to 1
def preprocess_images(dataset):
    """Rescale pixel intensities with ``(dataset - 127.5) / 127.5``.

    A value of 0 becomes -1.0, 127.5 becomes 0.0 and 255 becomes 1.0.

    Args:
        dataset: Pixel data. Either an object that already supports
            element-wise arithmetic (e.g. a NumPy array of any numeric
            dtype) or a plain ``list``/``tuple`` of numbers or arrays.

    Returns:
        The rescaled data. Lists and tuples are returned as NumPy arrays;
        any other input keeps whatever type its own arithmetic produces.
    """
    # Plain Python sequences do not support element-wise arithmetic
    # (``list - float`` raises TypeError), so promote them to an array
    # first. Every other input is used untouched.
    if isinstance(dataset, (list, tuple)):
        dataset = np.asarray(dataset)
    return (dataset - 127.5) / 127.5
# Defining process_images, which runs the SRCNN on the generator images and saves the results
def process_images(dataset_lr_array, model, count, savepath):
    """Run ``model`` on ``dataset_lr_array`` and save the first ``count`` outputs.

    The prediction at index ``i`` (0-based) is written to
    ``savepath + str(i + 1) + '.jpg'``: files are numbered from 1 and
    ``savepath`` is used as a plain string prefix.
    """
    predictions = model.predict(dataset_lr_array)
    for number in range(1, count + 1):
        target = savepath + str(number) + '.jpg'
        tf.keras.preprocessing.image.save_img(target, predictions[number - 1])
# Defining make_srcnn, which builds and compiles our SRCNN model
# With the default in_shape the input is a 768x432 image (upscaled generator image),
# and the output is a 768x432 image
def make_srcnn(in_shape=(432, 768, 3)):
    """Build, compile and summarise the SRCNN model.

    The network halves the spatial resolution twice with max pooling,
    squeezes the result through two dense layers, then upsamples twice,
    adding the matching pre-pooling feature maps back in as skip
    connections. A final 1x1 convolution with tanh produces a 3-channel
    output.

    Args:
        in_shape: ``(height, width, channels)`` of the input images. Height
            and width must be multiples of 4 so that the two 2x upsampling
            steps reproduce the shapes of the skip connections.

    Returns:
        The compiled ``tf.keras.Model``. ``model.summary()`` is printed as a
        side effect.

    Raises:
        ValueError: If ``in_shape`` does not have exactly three entries or
            its height/width is not a multiple of 4.
    """
    height, width, _ = in_shape
    if height % 4 or width % 4:
        raise ValueError(
            "in_shape height and width must be multiples of 4, got "
            + str(tuple(in_shape))
        )
    # Spatial size after the two stride-2 poolings; the dense bottleneck is
    # reshaped back to this size with 9 channels.
    bottleneck_h, bottleneck_w = height // 4, width // 4

    def conv(kernel_size):
        # Every hidden convolution shares the same configuration apart from
        # the kernel size. NOTE: 10e-12 is 1e-11.
        return tf.keras.layers.Conv2D(
            6,
            kernel_size=kernel_size,
            padding='same',
            kernel_regularizer=tf.keras.regularizers.l1(10e-12),
        )

    # Encoder
    layer1 = tf.keras.layers.Input(shape=tuple(in_shape))
    layer2 = conv((4, 4))(layer1)
    layer3 = tf.keras.layers.MaxPooling2D(padding='same', strides=(2, 2))(layer2)
    layer4 = conv((6, 6))(layer3)
    layer5 = tf.keras.layers.MaxPooling2D(padding='same', strides=(2, 2))(layer4)
    layer6 = conv((2, 2))(layer5)

    # Dense bottleneck
    layer7 = tf.keras.layers.Flatten()(layer6)
    layer8 = tf.keras.layers.Dense(60)(layer7)
    layer9 = tf.keras.layers.Dense(bottleneck_h * bottleneck_w * 9)(layer8)
    layer10 = tf.keras.layers.Reshape(target_shape=(bottleneck_h, bottleneck_w, 9))(layer9)

    # Decoder with skip connections to the encoder feature maps
    layer11 = conv((2, 2))(layer10)
    layer12 = tf.keras.layers.UpSampling2D(size=(2, 2))(layer11)
    layer13 = conv((6, 6))(layer12)
    layer14 = tf.keras.layers.Add()([layer13, layer4])
    layer15 = tf.keras.layers.UpSampling2D(size=(2, 2))(layer14)
    layer16 = conv((4, 4))(layer15)
    layer17 = tf.keras.layers.Add()([layer16, layer2])
    layer18 = tf.keras.layers.Conv2D(3, kernel_size=(1, 1), padding='same')(layer17)
    # tanh keeps the output in [-1, 1]
    layer19 = tf.keras.layers.Activation(tf.keras.activations.tanh)(layer18)

    # `learning_rate` is the supported keyword; the old `lr` alias is
    # deprecated and rejected by newer Keras releases.
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.95)
    model = tf.keras.Model(inputs=layer1, outputs=layer19)
    # The 'accuracy' metric is kept because plot_training_metrics reads it
    # from the training history.
    model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
    model.summary()
    return model
# Defining train, used for training the SRCNN model
def train(model, dataset_hr, dataset_lr, n_epochs=1000, batch_size=6):
    """Fit ``model`` to map ``dataset_lr`` onto ``dataset_hr`` and plot the metrics.

    Args:
        model: A compiled Keras model (as returned by ``make_srcnn``).
        dataset_hr: Target images, passed to ``model.fit`` as ``y``.
        dataset_lr: Input images, passed to ``model.fit`` as ``x``.
        n_epochs: Number of passes over the training data.
        batch_size: Number of samples per gradient update.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Pass the batch size through and let Keras derive the number of steps
    # per epoch from the amount of data, so one epoch is exactly one pass
    # over the dataset regardless of how many epochs are requested.
    history = model.fit(
        x=dataset_lr,
        y=dataset_hr,
        epochs=n_epochs,
        batch_size=batch_size,
        verbose=2,
    )
    plot_training_metrics(history)
    return history
# Defining plot_training_metrics, which plots the loss and accuracy after training
def plot_training_metrics(history):
    """Show the per-epoch training loss and training accuracy side by side.

    Reads the ``'loss'`` and ``'accuracy'`` series from ``history.history``
    and draws each one in its own subplot of a single 12x5 figure.
    """
    metrics = history.history
    figure = plt.figure(figsize=(12, 5))
    # (history key, subplot title), in left-to-right order
    panels = (('loss', 'Training loss'), ('accuracy', 'Training accuracy'))
    for position, (key, title) in enumerate(panels, start=1):
        axis = figure.add_subplot(1, 2, position)
        axis.plot(metrics[key], lw=3)
        axis.set_title(title, size=15)
        axis.set_xlabel('Epoch', size=15)
        axis.tick_params(axis='both', which='major', labelsize=15)
    plt.show()
# Counting the files directly inside a directory
def _count_files(directory):
    """Return how many regular files sit directly inside ``directory``."""
    return sum(1 for path in pathlib.Path(directory).iterdir() if path.is_file())


# Loading the images 1.jpg, 2.jpg, ... from a directory
def _load_images(directory, count):
    """Load ``<directory>/1.jpg`` .. ``<directory>/<count>.jpg`` as one rescaled array."""
    arrays = []
    for i in range(1, count + 1):
        image = tf.keras.preprocessing.image.load_img(directory + "/" + str(i) + ".jpg")
        # Only the array form is kept; holding on to the decoded PIL images
        # as well would keep every image in memory twice.
        arrays.append(tf.keras.preprocessing.image.img_to_array(image))
    return preprocess_images(dataset=np.asarray(arrays))


# Counting the number of images in the sampleimages_midres directory
# These images are the training targets for the SRCNN
count_mr = _count_files("sampleimages_midres")
print(count_mr)
# Counting the number of images in the sampleimages_lowres_upscaled directory
# These images are the training inputs for the SRCNN; the count is taken from
# the same directory the images are loaded from so the two cannot disagree
count_lr = _count_files("sampleimages_lowres_upscaled")
print(count_lr)
# Inputs and targets are paired by file number, so both sets must be the same size
if count_mr != count_lr:
    raise ValueError(
        "sampleimages_midres has " + str(count_mr) + " files but "
        "sampleimages_lowres_upscaled has " + str(count_lr)
    )
# Loading in the images from the sampleimages_midres directory
training_images_mr_array = _load_images("sampleimages_midres", count_mr)
# Loading in the images from the sampleimages_lowres_upscaled directory
training_images_lr_array = _load_images("sampleimages_lowres_upscaled", count_lr)
# Making an instance of our SRCNN model
srcnn_model = make_srcnn()
# Training the SRCNN model
train(srcnn_model, dataset_hr=training_images_mr_array, dataset_lr=training_images_lr_array)
# Saving the SRCNN model after training
srcnn_model.save('models/srcnn_model/model1/')