-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtrain.py
92 lines (73 loc) · 4.13 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import os
import random as rn
import tensorflow as tf
import time
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from keras.callbacks import ReduceLROnPlateau
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import models
from utils import metrics
# Training hyper-parameters.
SEED = 230
BATCH_SIZE = 8
NO_OF_EPOCHS = 500
IMAGE_SIZE = 256, 256  # (height, width) every frame/mask is resized to

# Dataset sizes, counted once at import time from the frame folders.
# NOTE(review): os.listdir counts every entry (e.g. stray .DS_Store files),
# not only the images flow_from_directory accepts -- confirm folders are clean.
NO_OF_TRAINING_IMAGES = len(os.listdir('dataset/train/train_frames/image'))
NO_OF_VAL_IMAGES = len(os.listdir('dataset/train/val_frames/image'))

# Seed TensorFlow's, NumPy's and Python's RNGs with the same value so that
# runs are as repeatable as the backend allows.
tf.set_random_seed(SEED)
np.random.seed(SEED)
rn.seed(SEED)
def _make_generator(datagen, directory):
    """Return a grayscale, label-less batch iterator over *directory*.

    Every iterator is built with the same ``IMAGE_SIZE``, ``BATCH_SIZE`` and
    ``SEED``. Using one seed for an image folder and its mask folder is the
    pattern Keras documents for keeping images and masks shuffled together.
    """
    return datagen.flow_from_directory(directory,
                                       target_size=IMAGE_SIZE,
                                       class_mode=None,
                                       batch_size=BATCH_SIZE,
                                       color_mode='grayscale',
                                       seed=SEED)


def _paired_generator(datagen, image_dir, mask_dir):
    """Build a zipped ``(image_batch, mask_batch)`` iterator for two folders.

    Returns a ``(iterator, sample_count)`` tuple, where ``sample_count`` is the
    number of files the image iterator found.

    Raises:
        ValueError: if no images were found, or if the image and mask folders
            hold different numbers of files. Seed-synchronised shuffling only
            pairs each frame with its own mask when both sides have the same
            length.
    """
    images = _make_generator(datagen, image_dir)
    masks = _make_generator(datagen, mask_dir)
    if not images.samples:
        raise ValueError('no images found under %r' % (image_dir,))
    if images.samples != masks.samples:
        raise ValueError('found %d images under %r but %d masks under %r; '
                         'frames and masks must pair one-to-one'
                         % (images.samples, image_dir, masks.samples, mask_dir))
    return zip(images, masks), images.samples


def _steps(num_samples):
    """Return how many batches of ``BATCH_SIZE`` cover *num_samples* once."""
    # Ceiling division, as the Keras docs recommend. The iterator also yields
    # the trailing partial batch, so flooring would make epochs drift away
    # from full passes. Flooring would also give 0 steps for tiny datasets.
    return (num_samples + BATCH_SIZE - 1) // BATCH_SIZE


def main():
    """Train a UNET on frame/mask pairs and keep the best model.

    Reads the train and validation frame/mask folders under ``dataset/train``.
    Compiles the model with Adam and binary cross-entropy, then trains it.
    The model with the highest validation mean IoU is saved to ``model.h5``.
    TensorBoard logs go to a timestamped folder under ``./logs/``.
    """
    # Scale pixel values by 1/255 (assumes 8-bit images); no augmentation.
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    train_generator, n_train = _paired_generator(
        train_datagen, './dataset/train/train_frames', 'dataset/train/train_masks')
    val_datagen = ImageDataGenerator(rescale=1. / 255)
    val_generator, n_val = _paired_generator(
        val_datagen, 'dataset/train/val_frames', 'dataset/train/val_masks')

    # build model: one grayscale channel at IMAGE_SIZE resolution
    model = models.UNET(input_size=IMAGE_SIZE + (1,))
    # To resume from a saved checkpoint instead, use:
    # model = load_model("model.h5", custom_objects={'mean_iou': metrics.mean_iou})
    model.compile(optimizer=Adam(lr=1e-4), loss='binary_crossentropy',
                  metrics=['accuracy', metrics.mean_iou])

    # Every callback tracks validation mean IoU, where higher is better.
    checkpoint = ModelCheckpoint("model.h5", verbose=1, save_best_only=True,
                                 save_weights_only=False,
                                 monitor='val_mean_iou', mode='max')
    earlystopping = EarlyStopping(patience=10, verbose=1,
                                  monitor='val_mean_iou', mode='max')
    reduce_lr = ReduceLROnPlateau(factor=0.2, patience=3, verbose=1,
                                  min_delta=0.000001,
                                  monitor='val_mean_iou', mode='max')
    tensorboard = TensorBoard(log_dir='./logs/' + time.strftime("%Y%m%d_%H%M%S"),
                              histogram_freq=0, write_graph=True,
                              write_images=True)

    # Step counts come from the iterators' own sample counts, so they match
    # exactly the files the generators will yield.
    model.fit_generator(train_generator, epochs=NO_OF_EPOCHS,
                        steps_per_epoch=_steps(n_train),
                        validation_data=val_generator,
                        validation_steps=_steps(n_val),
                        callbacks=[checkpoint, earlystopping, reduce_lr, tensorboard])
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()