-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmain.py
103 lines (79 loc) · 2.44 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import math
import numpy as np
import tensorflow as tf
from model import RFDNNet
from tensorflow import keras
from IPython.display import display
from tensorflow.keras import Input, Model
from tensorflow.keras.preprocessing import image_dataset_from_directory
from utils import *
# Training hyperparameters.
crop_size = 300  # square side length used as image_size when loading the images
upscale_factor = 3  # passed to process_input and to the model builder
batch_size = 8

# Fail early (before the download starts) when the crop is not an exact
# multiple of the factor: the floor division below would otherwise make
# input_size * upscale_factor differ from crop_size without any warning.
if crop_size % upscale_factor:
    raise ValueError(
        f"crop_size ({crop_size}) must be divisible by "
        f"upscale_factor ({upscale_factor})"
    )
input_size = crop_size // upscale_factor

# BSDS500 archive; Keras fetches it and, with untar=True, unpacks the tarball.
dataset_url = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
data_dir = keras.utils.get_file(origin=dataset_url, fname="BSR", untar=True)
# Join component by component so the platform's own separator is used.
root_dir = os.path.join(data_dir, "BSDS500", "data")
# Keyword arguments shared by both splits. The split fraction and the seed
# are deliberately identical for the two calls; only `subset` differs.
# label_mode=None means the datasets yield images only, without labels.
_split_options = dict(
    batch_size=batch_size,
    image_size=(crop_size, crop_size),
    validation_split=0.2,
    seed=1337,
    label_mode=None,
)

# `scaling` comes from utils and is applied to every element of each split;
# per the original author's note it rescales pixels from (0, 255) to (0, 1).
train_ds = image_dataset_from_directory(
    root_dir, subset="training", **_split_options
).map(scaling)
valid_ds = image_dataset_from_directory(
    root_dir, subset="validation", **_split_options
).map(scaling)
# Locate the "images/test" folder under the dataset root and list every
# ".jpg" file in it as a full path, in sorted order.
dataset = os.path.join(root_dir, "images")
test_path = os.path.join(dataset, "test")
# All results share the same directory prefix, so sorting the bare file
# names first yields the same order as sorting the joined paths.
test_img_paths = [
    os.path.join(test_path, jpg_name)
    for jpg_name in sorted(os.listdir(test_path))
    if jpg_name.endswith(".jpg")
]
def _make_input_target_pair(images):
    """Return the (model input, training target) tuple for one dataset element.

    Both halves are produced by the helpers from utils: `process_input`
    receives the element together with `input_size` and `upscale_factor`,
    `process_target` receives the element alone.
    """
    return process_input(images, input_size, upscale_factor), process_target(images)


# Turn each split into (input, target) pairs and prefetch up to 32 elements.
train_ds = train_ds.map(_make_input_target_pair).prefetch(buffer_size=32)
valid_ds = valid_ds.map(_make_input_target_pair).prefetch(buffer_size=32)
# Assemble the network with the Keras functional API and print its summary.
# RFDNNet is the project's builder class imported from the `model` module.
rfanet_x = RFDNNet()
# The first two dimensions are left unspecified (None) and the last is fixed
# at 3 -- presumably (height, width, RGB channels); confirm against utils.
x = Input(shape=(None, None, 3))
# main_model receives the symbolic input and the upscale factor and returns
# the tensor used as the model output.
out = rfanet_x.main_model(x, upscale_factor)
model = Model(inputs=x, outputs=out)
model.summary()
# Stop training when the monitored training loss has not improved for
# 10 consecutive epochs.
early_stopping_callback = keras.callbacks.EarlyStopping(monitor="loss", patience=10)

# Directory that receives the checkpoint file; create it up front so the
# first save cannot fail on a missing folder.
checkpoint_filepath = "weights"
os.makedirs(checkpoint_filepath, exist_ok=True)

# Keep only the model with the lowest monitored loss seen so far.
# The deprecated `period=1` argument was removed: ModelCheckpoint already
# checks once per epoch by default (save_freq="epoch"), and newer Keras
# releases no longer accept `period`.
# NOTE(review): both callbacks watch the training "loss" even though
# validation data is passed to fit(); consider "val_loss" -- confirm intent.
model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
    os.path.join(checkpoint_filepath, "best.h5"),
    monitor="loss",
    mode="min",
    save_best_only=True,
)

# ESPCNCallback is the project callback from utils, built from the test
# image paths; it is listed first, ahead of early stopping and checkpointing.
callbacks = [
    ESPCNCallback(test_img_paths),
    early_stopping_callback,
    model_checkpoint_callback,
]
# Optimisation setup: mean-squared-error loss minimised with Adam.
loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
epochs = 100

model.compile(loss=loss_fn, optimizer=optimizer)

# Train on the prepared pairs and evaluate on the validation split each epoch.
# verbose=2 selects Keras' one-line-per-epoch logging.
model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=epochs,
    callbacks=callbacks,
    verbose=2,
)