Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend ImageGenerator to multilabel case #6128

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,8 @@ def flow_from_directory(self, directory,
save_to_dir=None,
save_prefix='',
save_format='jpeg',
follow_links=False):
follow_links=False,
multilabel_classes=None):
return DirectoryIterator(
directory, self,
target_size=target_size, color_mode=color_mode,
Expand All @@ -471,7 +472,8 @@ def flow_from_directory(self, directory,
save_to_dir=save_to_dir,
save_prefix=save_prefix,
save_format=save_format,
follow_links=follow_links)
follow_links=follow_links,
multilabel_classes=multilabel_classes)

def standardize(self, x):
"""Apply the normalization configuration to a batch of inputs.
Expand Down Expand Up @@ -838,6 +840,7 @@ class DirectoryIterator(Iterator):
`"binary"`: binary targets (if there are only two classes),
`"categorical"`: categorical targets,
`"sparse"`: integer targets,
`"multilabel"`: multiple categorical targets,
`None`: no targets get yielded (only input images are yielded).
batch_size: Integer, size of a batch.
shuffle: Boolean, whether to shuffle the data between epochs.
Expand All @@ -851,6 +854,7 @@ class DirectoryIterator(Iterator):
images (if `save_to_dir` is set).
save_format: Format to use for saving sample images
(if `save_to_dir` is set).
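multilabel_classes: Dictionary mapping each image path, written as
`os.path.join(subdir, filename)`, to a Numpy array of length
`num_class` holding that image's binary class indicators
(e.g. `{'folder/some_image.jpg': np.array([0., 1., 0., 1.]), ...}`).
Looked up for every image when `class_mode` is `"multilabel"`.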
"""

def __init__(self, directory, image_data_generator,
Expand All @@ -859,7 +863,7 @@ def __init__(self, directory, image_data_generator,
batch_size=32, shuffle=True, seed=None,
data_format=None,
save_to_dir=None, save_prefix='', save_format='jpeg',
follow_links=False):
follow_links=False, multilabel_classes=None):
if data_format is None:
data_format = K.image_data_format()
self.directory = directory
Expand All @@ -881,11 +885,12 @@ def __init__(self, directory, image_data_generator,
else:
self.image_shape = (1,) + self.target_size
self.classes = classes
if class_mode not in {'categorical', 'binary', 'sparse', None}:
if class_mode not in {'categorical', 'binary', 'sparse', 'multilabel', None}:
raise ValueError('Invalid class_mode:', class_mode,
'; expected one of "categorical", '
'"binary", "sparse", or None.')
'"binary", "sparse","multilabel" or None.')
self.class_mode = class_mode
self.multilabel_classes = multilabel_classes
self.save_to_dir = save_to_dir
self.save_prefix = save_prefix
self.save_format = save_format
Expand Down Expand Up @@ -921,7 +926,11 @@ def _recursive_list(subpath):

# second, build an index of the images in the different class subfolders
self.filenames = []
self.classes = np.zeros((self.samples,), dtype='int32')
# In the multilabel case, store one binary label vector per sample
# (shape `(samples, num_class)`) instead of a single integer class index.
if(self.class_mode == 'multilabel'):
self.classes = np.zeros((self.samples, self.num_class), dtype='int32')
else:
self.classes = np.zeros((self.samples,), dtype='int32')
i = 0
for subdir in classes:
subpath = os.path.join(directory, subdir)
Expand All @@ -933,7 +942,11 @@ def _recursive_list(subpath):
is_valid = True
break
if is_valid:
self.classes[i] = self.class_indices[subdir]
if(self.class_mode == 'multilabel'):
fileid = os.path.join(subdir, fname)
self.classes[i] = self.multilabel_classes[fileid]
else:
self.classes[i] = self.class_indices[subdir]
i += 1
# add filename relative to directory
absolute_path = os.path.join(root, fname)
Expand Down Expand Up @@ -980,6 +993,8 @@ def next(self):
batch_y = np.zeros((len(batch_x), self.num_class), dtype=K.floatx())
for i, label in enumerate(self.classes[index_array]):
batch_y[i, label] = 1.
elif self.class_mode == 'multilabel':
batch_y = self.classes[index_array]
else:
return batch_x
return batch_x, batch_y