Fixed multiprocessing in generators. #7118

Closed
wants to merge 17 commits into from
16 changes: 11 additions & 5 deletions keras/engine/training.py
@@ -1750,7 +1750,7 @@ def generate_arrays_from_file(path):
            'metrics': callback_metrics,
        })
        callbacks.on_train_begin()

+       print(do_validation, val_gen)
        if do_validation and not val_gen:
            if len(validation_data) == 2:
                val_x, val_y = validation_data
@@ -1923,12 +1923,18 @@ def evaluate_generator(self, generator, steps,
        is_sequence = isinstance(generator, Sequence)
        if not is_sequence and use_multiprocessing and workers > 1:
            warnings.warn(
-               UserWarning('Using a generator with `use_multiprocessing=True`'
-                           ' and multiple workers may duplicate your data.'
-                           ' Please consider using the `keras.utils.Sequence`'
-                           ' class.'))
+               UserWarning('Please consider using '
+                           'the `keras.utils.Sequence` class.'))
        enqueuer = None

+       # Reset the generator; necessary to release any locks it may hold.
+       if hasattr(generator, "reset"):
+           generator.reset()
+       else:
+           warnings.warn(
+               UserWarning('Generator has no reset function; if using '
+                           'multiprocessing, a deadlock may occur.'))

        try:
            if is_sequence:
                enqueuer = OrderedEnqueuer(generator,
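Note on the reset hook above: `evaluate_generator` only duck-types for a `reset` attribute, so any plain generator can be made compatible by wrapping it. A minimal sketch of such a wrapper (the class and factory names here are hypothetical, not part of this PR):

    class ResettableGenerator(object):
        """Wraps a zero-argument generator factory so it can be reset."""

        def __init__(self, make_generator):
            self._make_generator = make_generator
            self._generator = make_generator()

        def reset(self):
            # Rebuild the underlying generator, discarding any stale state.
            self._generator = self._make_generator()

        def __iter__(self):
            return self

        def __next__(self):
            return next(self._generator)

        next = __next__  # Python 2 compatibility

    # Usage: model.evaluate_generator(ResettableGenerator(my_batches), steps)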
104 changes: 82 additions & 22 deletions keras/preprocessing/image.py
@@ -679,6 +679,23 @@ def fit(self, x,
        self.principal_components = np.dot(np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T)


class ValueStruct(object):
    """Plain container that exposes `value` data.

    Provides a common interface regardless of whether threads or
    processes are used.

    # Arguments
        val: Object, the data that `.value` should return.

    # Attributes
        value: Object, holder for any data passed in.
    """

    def __init__(self, val):
        self.value = val
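The point of `ValueStruct` is that it mirrors the `.value` attribute of `multiprocessing.RawValue`, so the iterator code can read and write its counters the same way under either backend. A quick illustration (the loop itself is illustrative, not from the PR):

    import multiprocessing

    counters = [multiprocessing.RawValue("i", 0),  # shared C int (processes)
                ValueStruct(0)]                    # plain attribute (threads)
    for counter in counters:
        counter.value += 1      # identical access pattern either way
        assert counter.value == 1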


class Iterator(object):
    """Abstract base class for image data iterators.

@@ -687,39 +704,76 @@ class Iterator(object):
        batch_size: Integer, size of a batch.
        shuffle: Boolean, whether to shuffle the data between epochs.
        seed: Random seeding for data shuffling.
+       use_multiprocessing: Boolean. If True, use process-based threading.
    """

-   def __init__(self, n, batch_size, shuffle, seed):
+   def __init__(self, n, batch_size, shuffle, seed, use_multiprocessing):
        self.n = n
        self.batch_size = batch_size
        self.shuffle = shuffle
-       self.batch_index = 0
-       self.total_batches_seen = 0
-       self.lock = threading.Lock()
+
+       # In multiprocessing we must use shared values for the lock to
+       # actually work; otherwise each process updates only its own copy.
+       if use_multiprocessing:
+           self.lock = multiprocessing.Lock()
+           self.index_array = multiprocessing.RawArray("i", np.arange(n))
+           self.current_index = multiprocessing.RawValue("i", 0)
+           self.current_batch_size = multiprocessing.RawValue("i", 0)
+           self.total_batches_seen = multiprocessing.RawValue("i", 0)
+           self.batch_index = multiprocessing.RawValue("i", 0)
+       else:
+           self.lock = threading.Lock()
+           # Emulate the C-typed shared variables from multiprocessing
+           # for the threading case.
+           self.index_array = np.arange(n)
+           self.current_index = ValueStruct(0)
+           self.current_batch_size = ValueStruct(0)
+           self.total_batches_seen = ValueStruct(0)
+           self.batch_index = ValueStruct(0)
+
        self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
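The branch above exists because a plain Python attribute is copied into each child process, so an increment in one worker is invisible to its siblings, while a `RawValue` lives in shared memory (note that `RawValue` itself is unsynchronised, which is why it is paired with an explicit `multiprocessing.Lock`). A standalone sketch of the difference, with hypothetical names:

    import multiprocessing

    def bump(counter):
        counter.value += 1  # the write lands in shared memory

    if __name__ == '__main__':
        shared = multiprocessing.RawValue('i', 0)
        child = multiprocessing.Process(target=bump, args=(shared,))
        child.start()
        child.join()
        print(shared.value)  # 1: the parent sees the child's increment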

    def reset(self):
-       self.batch_index = 0
+       # Release any lock held by a dead process; a failed non-blocking
+       # acquire is the only way to tell that the lock is held.
+       if not self.lock.acquire(False):
+           self.lock.release()
+       else:
+           self.lock.release()
+       # Lock just in case, even though this is only called from a
+       # single thread.
+       with self.lock:
+           self.batch_index.value = 0
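The `acquire(False)`/`release()` dance above is the only portable way to ask whether a lock is held: a non-blocking acquire either grabs a free lock (which is then released again) or fails because a dead worker still holds it (in which case it is released on that worker's behalf). A self-contained illustration of the idiom (hypothetical, using `threading` for brevity):

    import threading

    lock = threading.Lock()
    lock.acquire()               # simulate a worker dying while holding it

    if not lock.acquire(False):  # acquire failed: somebody holds the lock
        lock.release()           # free it on the dead holder's behalf
    else:
        lock.release()           # we grabbed a free lock; hand it back

    assert lock.acquire(False)   # either way, the lock is usable again
    lock.release()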

    def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
        # Ensure self.batch_index is 0.
        self.reset()
        while 1:
-           if seed is not None:
-               np.random.seed(seed + self.total_batches_seen)
-           if self.batch_index == 0:
-               index_array = np.arange(n)
-               if shuffle:
-                   index_array = np.random.permutation(n)
+           # Cache the shared values in locals to avoid repeated
+           # attribute lookups.
+           total_batches_seen = self.total_batches_seen.value
+           batch_index = self.batch_index.value
+           index_array = self.index_array

-           current_index = (self.batch_index * batch_size) % n
+           if seed is not None:
+               np.random.seed(seed + total_batches_seen)
+           if batch_index == 0 and shuffle:
+               perm = np.random.permutation(n)
+               # index_array is a ctypes array under multiprocessing,
+               # so update it element-wise; this works for both backends.
+               for i in range(n):
+                   index_array[i] = perm[i]
+
+           current_index = (batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
-               self.batch_index += 1
+               batch_index += 1
            else:
                current_batch_size = n - current_index
-               self.batch_index = 0
-           self.total_batches_seen += 1
+               batch_index = 0

+           total_batches_seen += 1
+           # Write the cached values back to the shared variables.
+           self.total_batches_seen.value = total_batches_seen
+           self.batch_index.value = batch_index
+           self.index_array = index_array

            yield (index_array[current_index: current_index + current_batch_size],
                   current_index, current_batch_size)
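Each `yield` hands back a slice of the (possibly shuffled) index array plus its offset and length; the final batch of an epoch is simply whatever remains, after which `batch_index` wraps to zero. The slicing arithmetic in isolation (a stripped-down sketch, not the class itself):

    def batch_spans(n, batch_size):
        """Reproduces _flow_index's offset/size arithmetic."""
        batch_index = 0
        while True:
            current_index = (batch_index * batch_size) % n
            if n > current_index + batch_size:
                current_batch_size = batch_size
                batch_index += 1
            else:
                current_batch_size = n - current_index
                batch_index = 0
            yield current_index, current_batch_size

    spans = batch_spans(10, 4)
    assert [next(spans) for _ in range(4)] == [(0, 4), (4, 4), (8, 2), (0, 4)]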

@@ -752,12 +806,14 @@ class NumpyArrayIterator(Iterator):
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
+       use_multiprocessing: Boolean. If True, use process-based threading.

    """

    def __init__(self, x, y, image_data_generator,
                 batch_size=32, shuffle=False, seed=None,
-                data_format=None,
-                save_to_dir=None, save_prefix='', save_format='png'):
+                data_format=None, save_to_dir=None, save_prefix='',
+                save_format='png', use_multiprocessing=False):
        if y is not None and len(x) != len(y):
            raise ValueError('X (images tensor) and y (labels) '
                             'should have the same length. '
@@ -789,7 +845,8 @@ def __init__(self, x, y, image_data_generator,
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format
-       super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)
+       super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size,
+                                                shuffle, seed, use_multiprocessing)
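With the flag threaded through the constructor, callers opt in per iterator; for example (hypothetical usage mirroring the new signature):

    it = NumpyArrayIterator(x, y, ImageDataGenerator(),
                            batch_size=32, shuffle=True,
                            use_multiprocessing=True)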

    def next(self):
        """For python 2.x.
@@ -925,6 +982,7 @@ class DirectoryIterator(Iterator):
            images (if `save_to_dir` is set).
        save_format: Format to use for saving sample images
            (if `save_to_dir` is set).
+       use_multiprocessing: Boolean. If True, use process-based threading.
    """

    def __init__(self, directory, image_data_generator,
@@ -933,7 +991,8 @@ def __init__(self, directory, image_data_generator,
                 batch_size=32, shuffle=True, seed=None,
                 data_format=None,
                 save_to_dir=None, save_prefix='', save_format='png',
-                follow_links=False):
+                follow_links=False,
+                use_multiprocessing=False):
        if data_format is None:
            data_format = K.image_data_format()
        self.directory = directory
@@ -1009,7 +1068,8 @@ def _recursive_list(subpath):
            i += len(classes)
        pool.close()
        pool.join()
-       super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed)
+       super(DirectoryIterator, self).__init__(self.samples, batch_size,
+                                               shuffle, seed, use_multiprocessing)

    def next(self):
        """For python 2.x.
86 changes: 77 additions & 9 deletions keras/utils/data_utils.py
@@ -519,6 +519,23 @@ def stop(self, timeout=None):
        self.run_thread.join(timeout)


class ValueStruct(object):
    """Plain container that exposes `value` data.

    Provides a common interface regardless of whether threads or
    processes are used.

    # Arguments
        val: Object, the data that `.value` should return.

    # Attributes
        value: Object, holder for any data passed in.
    """

    def __init__(self, val):
        self.value = val


class GeneratorEnqueuer(SequenceEnqueuer):
    """Builds a queue out of a data generator.

@@ -542,6 +559,9 @@ def __init__(self, generator,
        self._threads = []
        self._stop_event = None
        self.queue = None
+       self.global_next_counter = None
+       self.global_put_counter = None
+       self.lock = None
        self.random_seed = random_seed

    def start(self, workers=1, max_queue_size=10):
@@ -552,13 +572,50 @@ def start(self, workers=1, max_queue_size=10):
            max_queue_size: queue size
                (when full, threads could block on `put()`)
        """

-       def data_generator_task():
+       def data_generator_task(init_counter, workers):
+           """Calls the generator's `next` method from multiple workers.
+
+           # Arguments
+               init_counter: Initial value of both internal counters
+                   (this worker's index).
+               workers: Total number of workers in use; used to advance
+                   the counters by one full round at a time.
+           """
+           next_counter = init_counter
+           put_counter = init_counter
            while not self._stop_event.is_set():
                try:
-                   if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
-                       generator_output = next(self._generator)
-                       self.queue.put(generator_output)
+                   if self._use_multiprocessing or \
+                           self.queue.qsize() < max_queue_size:
+                       # Block until it is our turn to call the
+                       # generator's `next` method.
+                       next_block = True
+                       while next_block:
+                           with self.lock:
+                               global_next_counter = \
+                                   self.global_next_counter.value
+                           if global_next_counter == next_counter:
+                               next_block = False
+                               next_counter += workers
+                               with self.lock:
+                                   self.global_next_counter.value += 1
+                               generator_output = next(self._generator)
+                           else:
+                               time.sleep(self.wait_time)
+                       # Block until it is our turn to put into the queue.
+                       block = True
+                       while block:
+                           with self.lock:
+                               global_put_counter = \
+                                   self.global_put_counter.value
+                           if put_counter == global_put_counter:
+                               self.queue.put(generator_output)
+                               block = False
+                               # Move on to our next batch.
+                               put_counter += workers
+                               # Signal other workers that they can put.
+                               with self.lock:
+                                   self.global_put_counter.value += 1
+                           else:
+                               time.sleep(self.wait_time)
                    else:
                        time.sleep(self.wait_time)
                except Exception:
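The two counters implement a round-robin token scheme: worker `i` takes turns `i`, `i + workers`, `i + 2 * workers`, and so on, first for calling `next` and then, independently, for `put`, so batches enter the queue in a deterministic order. The pattern in isolation (a thread-based sketch with hypothetical names, not the enqueuer itself):

    import threading

    lock = threading.Lock()
    turn = [0]  # stand-in for global_next_counter

    def worker(worker_id, workers, rounds):
        my_turn = worker_id
        for _ in range(rounds):
            while True:               # spin until it is our turn
                with lock:
                    if turn[0] == my_turn:
                        turn[0] += 1  # pass the token onwards
                        break
            my_turn += workers        # our next turn, one round later

    threads = [threading.Thread(target=worker, args=(i, 3, 4))
               for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    assert turn[0] == 12  # 3 workers x 4 rounds, taken strictly in order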
@@ -569,21 +626,30 @@
            if self._use_multiprocessing:
                self.queue = multiprocessing.Queue(maxsize=max_queue_size)
                self._stop_event = multiprocessing.Event()
+               self.lock = multiprocessing.Lock()
+               self.global_put_counter = multiprocessing.RawValue("i", 0)
+               self.global_next_counter = multiprocessing.RawValue("i", 0)
            else:
                self.queue = queue.Queue()
                self._stop_event = threading.Event()
+               self.lock = threading.Lock()
+               self.global_put_counter = ValueStruct(0)
+               self.global_next_counter = ValueStruct(0)

-           for _ in range(workers):
+           for init_counter in range(workers):
                if self._use_multiprocessing:
                    # Reset the random seed, otherwise all child
                    # processes share the same seed.
                    np.random.seed(self.random_seed)
-                   thread = multiprocessing.Process(target=data_generator_task)
+                   thread = multiprocessing.Process(
+                       target=data_generator_task,
+                       args=(init_counter, workers))
                    thread.daemon = True
                    if self.random_seed is not None:
                        self.random_seed += 1
                else:
-                   thread = threading.Thread(target=data_generator_task)
+                   thread = threading.Thread(target=data_generator_task,
+                                             args=(init_counter, workers))
                self._threads.append(thread)
                thread.start()
        except:
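The per-worker reseeding matters because, on fork-based platforms, every child process inherits the parent's NumPy RNG state and would otherwise produce identical "random" augmentations. A minimal demonstration of the inherited-state problem (illustrative only; behaviour differs under spawn-based platforms such as Windows):

    import multiprocessing

    import numpy as np

    def draw(q):
        q.put(np.random.randint(0, 1000))

    if __name__ == '__main__':
        np.random.seed(0)
        q = multiprocessing.Queue()
        procs = [multiprocessing.Process(target=draw, args=(q,))
                 for _ in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        print(q.get(), q.get())  # same number twice: both forks shared a seed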
@@ -614,10 +680,12 @@ def stop(self, timeout=None):
        if self._use_multiprocessing:
            if self.queue is not None:
                self.queue.close()

        self._threads = []
        self._stop_event = None
        self.queue = None
+       self.lock = None
+       self.global_put_counter = None
+       self.global_next_counter = None

    def get(self):
        """Creates a generator to extract data from the queue.