Fix the ordering bugs when using pickle_safe=True #6891

Merged · 29 commits · Jun 20, 2017
214 changes: 71 additions & 143 deletions keras/engine/training.py
@@ -4,12 +4,13 @@

import warnings
import copy
import time
import numpy as np
import multiprocessing
import threading
import six

from keras.utils import Sequence
from keras.utils import GeneratorEnqueuer
from keras.utils import OrderedEnqueuer

try:
import queue
except ImportError:
@@ -579,101 +580,6 @@ def _standardize_weights(y, sample_weight=None, class_weight=None,
return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx())


class GeneratorEnqueuer(object):
"""Builds a queue out of a data generator.

Used in `fit_generator`, `evaluate_generator`, `predict_generator`.

# Arguments
generator: a generator function which endlessly yields data
pickle_safe: use multiprocessing if True, otherwise threading
"""

def __init__(self, generator, pickle_safe=False):
self._generator = generator
self._pickle_safe = pickle_safe
self._threads = []
self._stop_event = None
self.queue = None

def start(self, workers=1, max_q_size=10, wait_time=0.05):
"""Kicks off threads which add data from the generator into the queue.

# Arguments
workers: number of worker threads
max_q_size: queue size (when full, threads could block on put())
wait_time: time to sleep in-between calls to put()
"""

def data_generator_task():
while not self._stop_event.is_set():
try:
if self._pickle_safe or self.queue.qsize() < max_q_size:
generator_output = next(self._generator)
self.queue.put(generator_output)
else:
time.sleep(wait_time)
except Exception:
self._stop_event.set()
raise

try:
if self._pickle_safe:
self.queue = multiprocessing.Queue(maxsize=max_q_size)
self._stop_event = multiprocessing.Event()
if hasattr(data_generator_task, 'lock'):
# We should replace the threading lock of the iterator
# with a process-safe lock.
data_generator_task.lock = multiprocessing.Lock()
else:
self.queue = queue.Queue()
self._stop_event = threading.Event()

for _ in range(workers):
if self._pickle_safe:
# Reset random seed else all children processes
# share the same seed
np.random.seed()
thread = multiprocessing.Process(target=data_generator_task)
thread.daemon = True
else:
thread = threading.Thread(target=data_generator_task)
self._threads.append(thread)
thread.start()
except:
self.stop()
raise

def is_running(self):
return self._stop_event is not None and not self._stop_event.is_set()

def stop(self, timeout=None):
"""Stop running threads and wait for them to exit, if necessary.

Should be called by the same thread which called start().

# Arguments
timeout: maximum time to wait on thread.join()
"""
if self.is_running():
self._stop_event.set()

for thread in self._threads:
if thread.is_alive():
if self._pickle_safe:
thread.terminate()
else:
thread.join(timeout)

if self._pickle_safe:
if self.queue is not None:
self.queue.close()

self._threads = []
self._stop_event = None
self.queue = None


class Model(Container):
"""The `Model` class adds training & evaluation routines to a `Container`.
"""
@@ -1720,18 +1626,24 @@ def fit_generator(self, generator,
validation_data=None,
validation_steps=None,
class_weight=None,
max_q_size=10,
max_queue_size=10,
workers=1,
pickle_safe=False,
use_multiprocessing=False,
initial_epoch=0):
"""Fits the model on data yielded batch-by-batch by a Python generator.

The generator is run in parallel to the model, for efficiency.
For instance, this allows you to do real-time data augmentation
on images on CPU in parallel to training your model on GPU.

The use of `keras.utils.Sequence` guarantees both the ordering
of batches and the single use of every input per epoch when
using `use_multiprocessing=True`.

# Arguments
generator: a generator.
generator: a generator or an instance of `keras.utils.Sequence`,
which avoids duplicated data
when using multiprocessing.
The output of the generator must be either
- a tuple (inputs, targets)
- a tuple (inputs, targets, sample_weights).
@@ -1756,10 +1668,10 @@ def fit_generator(self, generator,
to yield from `generator` before stopping.
class_weight: dictionary mapping class indices to a weight
for the class.
max_q_size: maximum size for the generator queue
max_queue_size: maximum size for the generator queue
workers: maximum number of processes to spin up
when using process-based threading
pickle_safe: if True, use process-based threading.
use_multiprocessing: if True, use process-based threading.
Note that because
this implementation relies on multiprocessing,
you should not pass
@@ -1804,7 +1716,8 @@ def generate_arrays_from_file(path):
# python 2 has 'next', 3 has '__next__'
# avoid any explicit version checks
val_gen = (hasattr(validation_data, 'next') or
hasattr(validation_data, '__next__'))
hasattr(validation_data, '__next__') or
isinstance(validation_data, Sequence))
if val_gen and not validation_steps:
raise ValueError('When using a generator for validation data, '
'you must specify a value for '
@@ -1854,25 +1767,29 @@ def generate_arrays_from_file(path):
val_data += [0.]
for cbk in callbacks:
cbk.validation_data = val_data
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing:
warnings.warn(
    UserWarning('Using a generator with `use_multiprocessing=True`'
                ' may duplicate your data.'
                ' Please consider using the `keras.utils.Sequence` object.'))
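# Editorial note (not part of this diff): with `use_multiprocessing=True`,
# each worker process gets its own copy of a plain Python generator, so the
# same batches can be produced several times. A `Sequence` is addressed by
# batch index instead, which is what lets the `OrderedEnqueuer` below avoid
# duplicates and preserve ordering.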
enqueuer = None

try:
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
enqueuer.start(max_q_size=max_q_size, workers=workers)
if is_sequence:
enqueuer = OrderedEnqueuer(generator, use_multiprocessing=use_multiprocessing)
else:
enqueuer = GeneratorEnqueuer(generator, use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()

callback_model.stop_training = False
while epoch < epochs:
callbacks.on_epoch_begin(epoch)
steps_done = 0
batch_index = 0
while steps_done < steps_per_epoch:
generator_output = None
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_output = enqueuer.queue.get()
break
else:
time.sleep(wait_time)
generator_output = next(output_generator)
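# Editorial note (not part of this diff): `enqueuer.get()` returns a
# generator that waits internally until a batch is available, so the
# explicit queue-polling and `time.sleep(wait_time)` loop removed above
# is no longer needed here.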

if not hasattr(generator_output, '__len__'):
raise ValueError('output of generator should be '
@@ -1923,9 +1840,9 @@ def generate_arrays_from_file(path):
val_outs = self.evaluate_generator(
validation_data,
validation_steps,
max_q_size=max_q_size,
max_queue_size=max_queue_size,
workers=workers,
pickle_safe=pickle_safe)
use_multiprocessing=use_multiprocessing)
else:
# No need for try/except because
# data has already been validated.
@@ -1954,7 +1871,7 @@ def generate_arrays_from_file(path):

@interfaces.legacy_generator_methods_support
def evaluate_generator(self, generator, steps,
max_q_size=10, workers=1, pickle_safe=False):
max_queue_size=10, workers=1, use_multiprocessing=False):
"""Evaluates the model on a data generator.

The generator should return the same kind of data
@@ -1963,12 +1880,15 @@ def evaluate_generator(self, generator, steps,
# Arguments
generator: Generator yielding tuples (inputs, targets)
or (inputs, targets, sample_weights),
or an instance of `keras.utils.Sequence`,
which avoids duplicated data
when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
max_q_size: maximum size for the generator queue
max_queue_size: maximum size for the generator queue
workers: maximum number of processes to spin up
when using process-based threading
pickle_safe: if True, use process-based threading.
use_multiprocessing: if True, use process-based threading.
Note that because
this implementation relies on multiprocessing,
you should not pass
@@ -1992,21 +1912,23 @@ def evaluate_generator(self, generator, steps,
wait_time = 0.01
all_outs = []
batch_sizes = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing:
warnings.warn(
    UserWarning('Using a generator with `use_multiprocessing=True`'
                ' may duplicate your data.'
                ' Please consider using the `keras.utils.Sequence` object.'))
enqueuer = None

try:
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
enqueuer.start(workers=workers, max_q_size=max_q_size)
if is_sequence:
enqueuer = OrderedEnqueuer(generator, use_multiprocessing=use_multiprocessing)
else:
enqueuer = GeneratorEnqueuer(generator, use_multiprocessing=use_multiprocessing, wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()

while steps_done < steps:
generator_output = None
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_output = enqueuer.queue.get()
break
else:
time.sleep(wait_time)

generator_output = next(output_generator)
Contributor:
What about failure cases? For example (#6928): it is possible that only x, only y, neither x nor y, or a tuple of some other unexpected size gets returned.

At a minimum, check the tuple size and throw an exception if it doesn't match expectations. There are probably other cases like this in this pull request; it might be worth double-checking.
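A minimal sketch of the kind of check being requested (illustrative only, not code from this PR; the names follow the surrounding diff):

    generator_output = next(output_generator)
    if not isinstance(generator_output, tuple) or len(generator_output) not in (2, 3):
        # Fail fast instead of silently unpacking an unexpected shape.
        raise ValueError('Output of generator should be a tuple '
                         '(x, y) or (x, y, sample_weight). Found: ' +
                         str(generator_output))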

if not hasattr(generator_output, '__len__'):
raise ValueError('output of generator should be a tuple '
'(x, y, sample_weight) '
@@ -2054,21 +1976,24 @@ def evaluate_generator(self, generator, steps,

@interfaces.legacy_generator_methods_support
def predict_generator(self, generator, steps,
max_q_size=10, workers=1,
pickle_safe=False, verbose=0):
max_queue_size=10, workers=1,
use_multiprocessing=False, verbose=0):
"""Generates predictions for the input samples from a data generator.

The generator should return the same kind of data as accepted by
`predict_on_batch`.

# Arguments
generator: Generator yielding batches of input samples.
generator: Generator yielding batches of input samples,
or an instance of `keras.utils.Sequence`,
which avoids duplicated data
when using multiprocessing.
steps: Total number of steps (batches of samples)
to yield from `generator` before stopping.
max_q_size: Maximum size for the generator queue.
max_queue_size: Maximum size for the generator queue.
workers: Maximum number of processes to spin up
when using process-based threading
pickle_safe: If `True`, use process-based threading.
use_multiprocessing: If `True`, use process-based threading.
Note that because
this implementation relies on multiprocessing,
you should not pass
@@ -2089,24 +2014,27 @@ def predict_generator(self, generator, steps,
steps_done = 0
wait_time = 0.01
all_outs = []
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing:
warnings.warn(
    UserWarning('Using a generator with `use_multiprocessing=True`'
                ' may duplicate your data.'
                ' Please consider using the `keras.utils.Sequence` object.'))
enqueuer = None

try:
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
enqueuer.start(workers=workers, max_q_size=max_q_size)
if is_sequence:
enqueuer = OrderedEnqueuer(generator, use_multiprocessing=use_multiprocessing)
else:
enqueuer = GeneratorEnqueuer(generator, use_multiprocessing=use_multiprocessing,
wait_time=wait_time)
enqueuer.start(workers=workers, max_queue_size=max_queue_size)
output_generator = enqueuer.get()

if verbose == 1:
progbar = Progbar(target=steps)

while steps_done < steps:
generator_output = None
while enqueuer.is_running():
if not enqueuer.queue.empty():
generator_output = enqueuer.queue.get()
break
else:
time.sleep(wait_time)

generator_output = next(output_generator)
if isinstance(generator_output, tuple):
# Compatibility with the generators
# used for training.
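For reference, here is a minimal sketch of the `keras.utils.Sequence` pattern the docstrings above refer to, together with the enqueuer usage this PR switches to. The class `ArrayBatches` and all variable names are hypothetical; the enqueuer calls follow the signatures used in this diff:

import numpy as np
from keras.utils import Sequence, OrderedEnqueuer

class ArrayBatches(Sequence):
    """Batches addressed by index, so workers can share an epoch without duplication."""

    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size

    def __len__(self):
        # Number of batches per epoch.
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        # Return batch `idx`; index-based access is what makes ordering and
        # single use per epoch possible under multiprocessing.
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

# Consumption mirrors the pattern introduced in fit_generator above:
seq = ArrayBatches(np.random.rand(256, 10), np.random.rand(256, 1))
enqueuer = OrderedEnqueuer(seq, use_multiprocessing=True)
enqueuer.start(workers=2, max_queue_size=10)
output_generator = enqueuer.get()
x_batch, y_batch = next(output_generator)  # batches arrive in index order
enqueuer.stop()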
4 changes: 3 additions & 1 deletion keras/legacy/interfaces.py
@@ -601,7 +601,9 @@ def generator_methods_args_preprocessor(args, kwargs):
('val_samples', 'steps'),
('nb_epoch', 'epochs'),
('nb_val_samples', 'validation_steps'),
('nb_worker', 'workers')],
('nb_worker', 'workers'),
('pickle_safe', 'use_multiprocessing'),
('max_q_size', 'max_queue_size')],
preprocessor=generator_methods_args_preprocessor)
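To illustrate the effect of the two new conversion entries, here is a minimal stand-alone sketch of this kind of keyword renaming (illustrative only; the real logic lives in `generator_methods_args_preprocessor` and the legacy interface decorators):

import warnings

def convert_legacy_kwargs(kwargs, conversions):
    """Rename deprecated keyword arguments, warning when an old name is used."""
    for old, new in conversions:
        if old in kwargs:
            warnings.warn('The `%s` argument has been renamed `%s`.' % (old, new))
            kwargs[new] = kwargs.pop(old)
    return kwargs

kwargs = convert_legacy_kwargs(
    {'pickle_safe': True, 'max_q_size': 5},
    [('pickle_safe', 'use_multiprocessing'), ('max_q_size', 'max_queue_size')])
# kwargs is now {'use_multiprocessing': True, 'max_queue_size': 5}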

