-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGenerator.py
65 lines (52 loc) · 2.12 KB
/
Generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os

import keras
import numpy as np
import PIL
import PIL.Image

from ModelConfig import *
class DataGenerator(keras.utils.Sequence):
    """Generates data for Keras.

    Each batch is ``(X, [B, X, F2, F5])`` where:

    * ``X`` holds the batch images scaled to ``[0, 1]``;
    * ``B`` is an all-zero placeholder target;
    * ``F2`` and ``F5`` are the two outputs of ``vgg.predict(X)``, used as
      ground-truth features for a perceptual loss.

    Args:
        folder: Directory containing the image files.
        img_list: File names, relative to ``folder``, of the images to serve.
        vgg: Model whose ``predict`` returns exactly two arrays for a batch.
        batch_size: Number of images per batch. Images that do not fill a
            complete final batch are skipped for that epoch.
        dim: Shape ``(height, width, channels)`` of one sample. Every image
            is resized to ``height x width``.
        shuffle: Whether to reshuffle the sample order after each epoch.
    """

    def __init__(self, folder, img_list, vgg, batch_size=32, dim=(32, 32, 3), shuffle=True):
        """Initialization."""
        # Newer Keras versions (Sequence/PyDataset) expect the base
        # initializer to run; in older Keras it is a no-op.
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.img_list = img_list
        self.vgg = vgg
        self.folder = folder
        self.len_data = len(img_list)
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of full batches per epoch."""
        # Floor division: the batch arrays are allocated with exactly
        # ``batch_size`` rows, so a trailing partial batch is dropped.
        return self.len_data // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index: Batch number, in ``[0, len(self))``.

        Returns:
            ``(X, [B, X, F2, F5])``; see the class docstring.
        """
        # Positions (into img_list) of the samples in this batch.
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        img_temp = [self.img_list[k] for k in indexes]
        X, B, F2, F5 = self.__data_generation(img_temp)
        return X, [B, X, F2, F5]

    def on_epoch_end(self):
        """Updates indexes after each epoch, reshuffling them if requested."""
        self.indexes = np.arange(self.len_data)
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, img_temp):
        """Generates data containing batch_size samples.

        Args:
            img_temp: File names (relative to ``self.folder``) of this batch.

        Returns:
            Tuple ``(X, B, F2, F5)``:

            * ``X``: float array of shape ``(batch_size, *dim)`` with pixel
              values divided by 255.
            * ``B``: zeros of shape ``(batch_size, h // 8, w // 8, 96)``.
            * ``F2``, ``F5``: the two outputs of ``self.vgg.predict(X)``.
        """
        height, width = self.dim[0], self.dim[1]
        wants_rgb = len(self.dim) == 3 and self.dim[2] == 3

        X = np.empty((self.batch_size, *self.dim))
        # B only keeps the network outputs consistent with the ground truths;
        # it is entirely filled with zeros. The 96 channels are hard-coded
        # (NOTE(review): presumably they match a model output — confirm
        # against the model definition).
        B = np.zeros((self.batch_size, height // 8, width // 8, 96))

        for i, name in enumerate(img_temp):
            path = os.path.join(self.folder, name)
            # The context manager closes the file handle that Image.open keeps.
            with PIL.Image.open(path) as src:
                # Normalize palette/grayscale/RGBA files to 3 channels so the
                # array matches ``dim``. Converting first also avoids PIL
                # forcing NEAREST resampling on palette images.
                img = src.convert("RGB") if wants_rgb else src
                # PIL sizes are (width, height); numpy shapes are (height, width).
                # LANCZOS is the filter formerly named ANTIALIAS
                # (ANTIALIAS was removed in Pillow 10).
                img = img.resize((width, height), PIL.Image.LANCZOS)
                X[i] = np.asarray(img) / 255

        # Now generate the features for the perceptual loss.
        # _make_predict_function is a private Keras 2.x API (presumably a
        # workaround for predicting from worker threads). Newer Keras does
        # not have it, so it is only called when available.
        make_predict_function = getattr(self.vgg, "_make_predict_function", None)
        if make_predict_function is not None:
            make_predict_function()
        F2, F5 = self.vgg.predict(X)
        return X, B, F2, F5