From 11edd92775558faa7a5edec5fe055703f63b378c Mon Sep 17 00:00:00 2001 From: NanoNabla <43477372+NanoNabla@users.noreply.github.com> Date: Tue, 16 Feb 2021 12:37:06 +0100 Subject: [PATCH 1/4] implement distributed training using horovod --- doc/TRAINING.rst | 15 + setup.py | 10 + training/deepspeech_training/train.py | 347 +++++++++++++++++-- training/deepspeech_training/util/config.py | 31 +- training/deepspeech_training/util/feeding.py | 27 +- training/deepspeech_training/util/flags.py | 2 + 6 files changed, 397 insertions(+), 35 deletions(-) diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index a5a08e240f..d193c2d52c 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -196,6 +196,21 @@ python3 DeepSpeech.py --train_files ./train.csv --dev_files ./dev.csv --test_fil On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech training and evaluation by ~30%-40%. +Distributed training using Horovod +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you have a capable compute architecture, we offer the opportunity to distribute the training using `Horovod `_. A fast network is recommended. +Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication. +It also offers Gloo as an easy-to-setup communication backend. + +For more information about setup or tuning of Horovod please visit `Horovod's Github `_. + +To train on 4 machines using 4 GPUs each: + +.. code-block:: bash + + horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python3 DeepSpeech.py --train_files [...] --horovod + Checkpointing ^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index b16e655289..f4487b6669 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,10 @@ def main(): 'tensorflow == 1.15.4' ] + horovod_pypi_dep = [ + 'horovod' + ] + # Due to pip craziness environment variables are the only consistent way to # get options into this script when doing `pip install`. 
tc_decoder_artifacts_root = os.environ.get('DECODER_ARTIFACTS_ROOT', '') @@ -94,6 +98,12 @@ def main(): else: install_requires = install_requires + tensorflow_pypi_dep + if os.environ.get('DS_NOHOROVOD', ''): + install_requires = install_requires + else: + install_requires = install_requires + horovod_pypi_dep + + setup( name='deepspeech_training', version=version, diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index 94ca7c04d3..d9bab5187f 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -424,7 +424,8 @@ def train(): process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, reverse=FLAGS.reverse_train, limit=FLAGS.limit_train, - buffering=FLAGS.read_buffer) + buffering=FLAGS.read_buffer, + split_dataset=False) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -442,7 +443,8 @@ def train(): process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in dev_sources] + buffering=FLAGS.read_buffer, + split_dataset=False) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] if FLAGS.metrics_files: @@ -454,7 +456,8 @@ def train(): process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer) for source in metrics_sources] + buffering=FLAGS.read_buffer, + split_dataset=False) for source in metrics_sources] metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] # Dropout @@ -677,6 +680,303 @@ def __call__(self, progress, data, **kwargs): log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) log_debug('Session closed.') +def train_with_horovod(): + + import horovod.tensorflow as hvd + + exception_box = ExceptionBox() + + # Create training and validation datasets + train_set = create_dataset(FLAGS.train_files.split(','), + batch_size=FLAGS.train_batch_size, + epochs=FLAGS.epochs, + augmentations=Config.augmentations, + cache_path=FLAGS.feature_cache, + train_phase=True, + exception_box=exception_box, + process_ahead=Config.num_devices * FLAGS.train_batch_size * 2, + reverse=FLAGS.reverse_train, + limit=FLAGS.limit_train, + buffering=FLAGS.read_buffer, + split_dataset=True) + + iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), + tfv1.data.get_output_shapes(train_set), + output_classes=tfv1.data.get_output_classes(train_set)) + + # Make initialization ops for switching between the two sets + train_init_op = iterator.make_initializer(train_set) + + if FLAGS.dev_files: + dev_sources = FLAGS.dev_files.split(',') + dev_sets = [create_dataset([source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + exception_box=exception_box, + process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2, + reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, + buffering=FLAGS.read_buffer, + split_dataset=True) for source in dev_sources] + dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] + + if FLAGS.metrics_files: + metrics_sources = FLAGS.metrics_files.split(',') + metrics_sets = [create_dataset([source], + batch_size=FLAGS.dev_batch_size, + train_phase=False, + exception_box=exception_box, + process_ahead=Config.num_devices * FLAGS.dev_batch_size * 2, + 
reverse=FLAGS.reverse_dev, + limit=FLAGS.limit_dev, + buffering=FLAGS.read_buffer, + split_dataset=True) for source in metrics_sources] + metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] + + # Dropout + dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] + dropout_feed_dict = { + dropout_rates[0]: FLAGS.dropout_rate, + dropout_rates[1]: FLAGS.dropout_rate2, + dropout_rates[2]: FLAGS.dropout_rate3, + dropout_rates[3]: FLAGS.dropout_rate4, + dropout_rates[4]: FLAGS.dropout_rate5, + dropout_rates[5]: FLAGS.dropout_rate6, + } + no_dropout_feed_dict = { + rate: 0. for rate in dropout_rates + } + + # Building the graph + learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False) + reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) + + # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size. + optimizer = create_optimizer(learning_rate_var * hvd.size()) + optimizer = hvd.DistributedOptimizer(optimizer) + + # Enable mixed precision training + if FLAGS.automatic_mixed_precision: + log_info('Enabling automatic mixed precision training.') + optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) + + loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False) + gradients = optimizer.compute_gradients(loss) + + tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries']) + log_grads_and_vars(gradients) + + # global_step is automagically incremented by the optimizer + global_step = tfv1.train.get_or_create_global_step() + apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step) + + # Summaries + step_summaries_op = tfv1.summary.merge_all('step_summaries') + step_summary_writers = { + 'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120), + 'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120), + 'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120), + } + + human_readable_set_names = { + 'train': 'Training', + 'dev': 'Validation', + 'metrics': 'Metrics', + } + + # Checkpointing + if Config.is_master_process: + checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) + checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train') + + best_dev_saver = tfv1.train.Saver(max_to_keep=1) + best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') + + # Save flags next to checkpoints + if not is_remote_path(FLAGS.save_checkpoint_dir): + os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) + flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') + with open_remote(flags_file, 'w') as fout: + fout.write(FLAGS.flags_into_string()) + + bcast = hvd.broadcast_global_variables(0) + + with tfv1.Session(config=Config.session_config) as session: + log_debug('Session opened.') + + # Prevent further graph changes + tfv1.get_default_graph().finalize() + + # Load checkpoint or initialize variables + load_or_init_graph_for_training(session) + bcast.run() + + def run_set(set_name, epoch, init_op, dataset=None): + is_train = set_name == 'train' + train_op = apply_gradient_op if is_train else [] + feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict + + total_loss = 
0.0 + step_count = 0 + + step_summary_writer = step_summary_writers.get(set_name) + checkpoint_time = time.time() + + if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache: + feature_cache_index = FLAGS.feature_cache + '.index' + if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index): + log_info('Invalidating feature cache') + remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files + + # Setup progress bar + class LossWidget(progressbar.widgets.FormatLabel): + def __init__(self): + progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f') + + def __call__(self, progress, data, **kwargs): + data['mean_loss'] = total_loss / step_count if step_count else 0.0 + return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) + + if Config.is_master_process: + # TODO endl seems not to work with horovod + prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) + widgets = [' | ', progressbar.widgets.Timer(), + ' | Steps: ', progressbar.widgets.Counter(), + ' | ', LossWidget()] + suffix = ' | Dataset: {}'.format(dataset) if dataset else None + pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() + + # Initialize iterator to the appropriate dataset + session.run(init_op) + + # Batch loop + while True: + try: + _, current_step, batch_loss, problem_files, step_summary = \ + session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], + feed_dict=feed_dict) + exception_box.raise_if_set() + except tf.errors.OutOfRangeError: + exception_box.raise_if_set() + break + + if problem_files.size > 0: + problem_files = [f.decode('utf8') for f in problem_files[..., 0]] + log_error('The following files caused an infinite (or NaN) ' + 'loss: {}'.format(','.join(problem_files))) + + total_loss += batch_loss + step_count += 1 + + if Config.is_master_process: + pbar.update(step_count) + step_summary_writer.add_summary(step_summary, current_step) + + if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: + checkpoint_saver.save(session, checkpoint_path, global_step=current_step) + checkpoint_time = time.time() + + pbar.finish() + mean_loss = total_loss / step_count if step_count > 0 else 0.0 + return mean_loss, step_count + + log_info('STARTING Optimization') + train_start_time = datetime.utcnow() + best_dev_loss = float('inf') + dev_losses = [] + epochs_without_improvement = 0 + try: + for epoch in range(FLAGS.epochs): + # Training + if Config.is_master_process: + log_progress('Training epoch %d...' % epoch) + train_loss, _ = run_set('train', epoch, train_init_op) + if Config.is_master_process: + log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) + checkpoint_saver.save(session, checkpoint_path, global_step=global_step) + + if FLAGS.dev_files: + # Validation + dev_loss = 0.0 + total_steps = 0 + for source, init_op in zip(dev_sources, dev_init_ops): + if Config.is_master_process: + log_progress('Validating epoch %d on %s...' 
% (epoch, source)) + set_loss, steps = run_set('dev', epoch, init_op, dataset=source) + dev_loss += set_loss * steps + total_steps += steps + if Config.is_master_process: + log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) + + dev_loss = dev_loss / total_steps + dev_losses.append(dev_loss) + + # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau + # the improvement has to be greater than FLAGS.es_min_delta + if dev_loss > best_dev_loss - FLAGS.es_min_delta: + epochs_without_improvement += 1 + else: + epochs_without_improvement = 0 + + if Config.is_master_process: + # Save new best model + if dev_loss < best_dev_loss: + best_dev_loss = dev_loss + save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, + latest_filename='best_dev_checkpoint') + log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) + + # Early stopping + if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs: + if Config.is_master_process: + log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( + epochs_without_improvement)) + break + + # Reduce learning rate on plateau + # If the learning rate was reduced and there is still no improvement + # wait FLAGS.plateau_epochs before the learning rate is reduced again + if ( + FLAGS.reduce_lr_on_plateau + and epochs_without_improvement > 0 + and epochs_without_improvement % FLAGS.plateau_epochs == 0 + ): + # Reload checkpoint that we use the best_dev weights again + reload_best_checkpoint(session) + + # Reduce learning rate + session.run(reduce_learning_rate_op) + current_learning_rate = learning_rate_var.eval() + if Config.is_master_process: + log_info('Encountered a plateau, reducing learning rate to {}'.format( + current_learning_rate)) + + # Overwrite best checkpoint with new learning rate value + save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, + latest_filename='best_dev_checkpoint') + log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) + + if FLAGS.metrics_files: + # Read only metrics, not affecting best validation loss tracking + for source, init_op in zip(metrics_sources, metrics_init_ops): + if Config.is_master_process: + log_progress('Metrics for epoch %d on %s...' 
% (epoch, source)) + set_loss, _ = run_set('metrics', epoch, init_op, dataset=source) + if Config.is_master_process: + log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss)) + + if Config.is_master_process: + print('-' * 80) + + + except KeyboardInterrupt: + pass + if Config.is_master_process: + log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) + if Config.is_master_process: + log_debug('Session closed.') + + def test(): samples = evaluate(FLAGS.test_files.split(','), create_model) @@ -951,30 +1251,35 @@ def main(_): if FLAGS.train_files: tfv1.reset_default_graph() tfv1.set_random_seed(FLAGS.random_seed) - train() - if FLAGS.test_files: - tfv1.reset_default_graph() - test() + if FLAGS.horovod: + train_with_horovod() + else: + train() - if FLAGS.export_dir and not FLAGS.export_zip: - tfv1.reset_default_graph() - export() + if Config.is_master_process: + if FLAGS.test_files: + tfv1.reset_default_graph() + test() - if FLAGS.export_zip: - tfv1.reset_default_graph() - FLAGS.export_tflite = True + if FLAGS.export_dir and not FLAGS.export_zip: + tfv1.reset_default_graph() + export() - if listdir_remote(FLAGS.export_dir): - log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) - sys.exit(1) + if FLAGS.export_zip: + tfv1.reset_default_graph() + FLAGS.export_tflite = True - export() - package_zip() + if listdir_remote(FLAGS.export_dir): + log_error('Directory {} is not empty, please fix this.'.format(FLAGS.export_dir)) + sys.exit(1) - if FLAGS.one_shot_infer: - tfv1.reset_default_graph() - do_single_file_inference(FLAGS.one_shot_infer) + export() + package_zip() + + if FLAGS.one_shot_infer: + tfv1.reset_default_graph() + do_single_file_inference(FLAGS.one_shot_infer) def run_script(): diff --git a/training/deepspeech_training/util/config.py b/training/deepspeech_training/util/config.py index 358aa6ab98..51fd012ce5 100755 --- a/training/deepspeech_training/util/config.py +++ b/training/deepspeech_training/util/config.py @@ -79,12 +79,33 @@ def initialize_globals(): # CPU device c.cpu_device = '/cpu:0' - # Available GPU devices - c.available_devices = get_available_gpus(c.session_config) + if FLAGS.horovod: + try: + import horovod.tensorflow as hvd + except ImportError as e: + print( + "Error importing Horovod. Did you install DeepSpeech with -DNOHOROVOD? 
" "If you do not want to use horovod, use 'from deepspeech_training import train'") + raise e + + hvd.init() + + # Pin GPU to be used to process local rank (one GPU per process) + c.session_config.gpu_options.visible_device_list = str(hvd.local_rank()) + c.num_devices = hvd.size() + c.is_master_process = True if hvd.rank() == 0 else False + else: + # Available GPU devices + c.available_devices = get_available_gpus(c.session_config) + + # If there is no GPU available, we fall back to CPU based operation + if not c.available_devices: + c.available_devices = [c.cpu_device] + + c.num_devices = len(c.available_devices) - # If there is no GPU available, we fall back to CPU based operation - if not c.available_devices: - c.available_devices = [c.cpu_device] + # If there are no Horovod processes, the single process should be handled like the Horovod master + c.is_master_process = True if FLAGS.bytes_output_mode: c.alphabet = UTF8Alphabet() diff --git a/training/deepspeech_training/util/feeding.py b/training/deepspeech_training/util/feeding.py index 30a2b2f470..739c040c2e 100644 --- a/training/deepspeech_training/util/feeding.py +++ b/training/deepspeech_training/util/feeding.py @@ -94,7 +94,8 @@ def create_dataset(sources, limit=0, exception_box=None, process_ahead=None, - buffering=1 * MEGABYTE): + buffering=1 * MEGABYTE, + split_dataset=False): epoch_counter = Counter() # survives restarts of the dataset and its generator def generate_values(): @@ -135,17 +136,25 @@ def batch_fn(sample_ids, features, features_len, transcripts): process_fn = partial(entry_to_features, train_phase=train_phase, augmentations=augmentations) - dataset = (tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), - output_types=(tf.string, tf.float32, tf.int32, - (tf.int64, tf.int32, tf.int64), tf.float64)) - .map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)) + dataset = tf.data.Dataset.from_generator(remember_exception(generate_values, exception_box), + output_types=(tf.string, tf.float32, tf.int32, + (tf.int64, tf.int32, tf.int64), tf.float64)) + if split_dataset: + # When using Horovod, Iterator.get_next() is not aware of the different devices. + # A.shard(n, i) will contain all elements of A whose index mod n = i. + import horovod.tensorflow as hvd + dataset = dataset.shard(hvd.size(), hvd.rank()) + dataset = dataset.map(process_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) if cache_path: dataset = dataset.cache(cache_path) - dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn) - .prefetch(len(Config.available_devices))) + dataset = (dataset.window(batch_size, drop_remainder=train_phase).flat_map(batch_fn)) + if split_dataset: + # TODO: is there a way to get a proper value? 
+ dataset = dataset.prefetch(2) + else: + dataset = dataset.prefetch(Config.num_devices) return dataset - def split_audio_file(audio_path, audio_format=DEFAULT_FORMAT, batch_size=1, @@ -178,5 +187,5 @@ def create_batch_set(bs, criteria): ods = create_batch_set(outlier_batch_size, lambda start, end, f, fl: end - start > int(outlier_duration_ms)) dataset = nds.concatenate(ods) - dataset = dataset.prefetch(len(Config.available_devices)) + dataset = dataset.prefetch(Config.num_devices) return dataset diff --git a/training/deepspeech_training/util/flags.py b/training/deepspeech_training/util/flags.py index fcbd6dd06a..2a28b4a9a9 100644 --- a/training/deepspeech_training/util/flags.py +++ b/training/deepspeech_training/util/flags.py @@ -69,6 +69,8 @@ def create_flags(): f.DEFINE_boolean('train_cudnn', False, 'use CuDNN RNN backend for training on GPU. Note that checkpoints created with this flag can only be used with CuDNN RNN, i.e. fine tuning on a CPU device will not work') f.DEFINE_boolean('automatic_mixed_precision', False, 'whether to allow automatic mixed precision training. USE OF THIS FLAG IS UNSUPPORTED. Checkpoints created with automatic mixed precision training will not be usable without mixed precision.') + f.DEFINE_boolean('horovod', False, 'use horovod for training on multiple gpus') + # Sample limits f.DEFINE_integer('limit_train', 0, 'maximum number of elements to use from train set - 0 means no limit') From 458b2e28a8ceebf477d14ef0f6abf6e88c6fbc65 Mon Sep 17 00:00:00 2001 From: NanoNabla <43477372+NanoNabla@users.noreply.github.com> Date: Tue, 16 Feb 2021 15:13:36 +0100 Subject: [PATCH 2/4] suggestions by lissyx in #3533 --- doc/TRAINING.rst | 4 ++-- setup.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index d193c2d52c..863ef477a9 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -199,9 +199,9 @@ On a Volta generation V100 GPU, automatic mixed precision speeds up DeepSpeech t Distributed training using Horovod ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -If you have a capable compute architecture, we offer the opportunity to distribute the training using `Horovod `_. A fast network is recommended. +If you have a capable compute architecture, it is possible to distribute the training using `Horovod `_. A fast network is recommended. Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication. -It also offers Gloo as an easy-to-setup communication backend. +It also offers `Gloo `_ as an easy-to-setup communication backend. For more information about setup or tuning of Horovod please visit `Horovod's Github `_. 
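Taken together, the changes in this series follow the standard Horovod recipe for TF1 graph-mode training: initialize Horovod, pin one GPU per process, shard the input pipeline across ranks, scale the learning rate by ``hvd.size()``, wrap the optimizer in ``hvd.DistributedOptimizer``, broadcast the initial variables from rank 0, and let only the master process write checkpoints and summaries. The following self-contained sketch shows that pattern in isolation on a toy regression problem; it is an illustration of the recipe (assuming TensorFlow 1.15 and ``horovod.tensorflow``), not the DeepSpeech training code itself.

.. code-block:: python

   import numpy as np
   import tensorflow.compat.v1 as tf
   import horovod.tensorflow as hvd

   hvd.init()  # one process per GPU, launched via horovodrun/mpirun

   # Pin each process to a single GPU, as initialize_globals() does via visible_device_list.
   config = tf.ConfigProto()
   config.gpu_options.visible_device_list = str(hvd.local_rank())

   # Toy data; each rank reads only its own shard, mirroring dataset.shard(hvd.size(), hvd.rank()).
   x = np.random.rand(1024, 1).astype(np.float32)
   y = 3.0 * x + 1.0
   dataset = (tf.data.Dataset.from_tensor_slices((x, y))
              .shard(hvd.size(), hvd.rank())
              .batch(32)
              .repeat())
   features, labels = tf.data.make_one_shot_iterator(dataset).get_next()

   # Minimal model and loss standing in for the acoustic model graph.
   prediction = tf.layers.dense(features, 1)
   loss = tf.losses.mean_squared_error(labels, prediction)

   # Scale the learning rate by the worker count and wrap the optimizer for gradient allreduce.
   optimizer = tf.train.AdamOptimizer(1e-3 * hvd.size())
   optimizer = hvd.DistributedOptimizer(optimizer)
   global_step = tf.train.get_or_create_global_step()
   train_op = optimizer.minimize(loss, global_step=global_step)

   # Broadcast the initial variables from rank 0 so every worker starts from identical weights.
   bcast = hvd.broadcast_global_variables(0)

   with tf.Session(config=config) as session:
       session.run(tf.global_variables_initializer())
       session.run(bcast)
       for _ in range(200):
           _, step_loss = session.run([train_op, loss])
       if hvd.rank() == 0:  # only the master process logs or checkpoints
           print('final loss:', step_loss)

Each process launched by ``horovodrun -np N`` runs this same script; Horovod averages the gradients across processes inside ``optimizer.minimize()``, which is what the ``train.py`` changes in this series do for the DeepSpeech graph.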
diff --git a/setup.py b/setup.py index f4487b6669..8b13956f72 100644 --- a/setup.py +++ b/setup.py @@ -77,7 +77,7 @@ def main(): ] horovod_pypi_dep = [ - 'horovod' + 'horovod[tensorflow] == 0.21.3' ] # Due to pip craziness environment variables are the only consistent way to @@ -98,10 +98,10 @@ def main(): else: install_requires = install_requires + tensorflow_pypi_dep - if os.environ.get('DS_NOHOROVOD', ''): - install_requires = install_requires - else: + if os.environ.get('DS_WITH_HOROVOD', ''): install_requires = install_requires + horovod_pypi_dep + else: + install_requires = install_requires setup( From 7db6b6282c500284eb5c1a53fa7121483b76f969 Mon Sep 17 00:00:00 2001 From: NanoNabla <43477372+NanoNabla@users.noreply.github.com> Date: Tue, 16 Feb 2021 15:35:37 +0100 Subject: [PATCH 3/4] merge train_with_horovod into train --- training/deepspeech_training/train.py | 364 ++++---------------------- 1 file changed, 53 insertions(+), 311 deletions(-) diff --git a/training/deepspeech_training/train.py b/training/deepspeech_training/train.py index d9bab5187f..4c86828305 100644 --- a/training/deepspeech_training/train.py +++ b/training/deepspeech_training/train.py @@ -413,280 +413,12 @@ def log_grads_and_vars(grads_and_vars): def train(): exception_box = ExceptionBox() - # Create training and validation datasets - train_set = create_dataset(FLAGS.train_files.split(','), - batch_size=FLAGS.train_batch_size, - epochs=FLAGS.epochs, - augmentations=Config.augmentations, - cache_path=FLAGS.feature_cache, - train_phase=True, - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.train_batch_size * 2, - reverse=FLAGS.reverse_train, - limit=FLAGS.limit_train, - buffering=FLAGS.read_buffer, - split_dataset=False) - - iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), - tfv1.data.get_output_shapes(train_set), - output_classes=tfv1.data.get_output_classes(train_set)) - - # Make initialization ops for switching between the two sets - train_init_op = iterator.make_initializer(train_set) - - if FLAGS.dev_files: - dev_sources = FLAGS.dev_files.split(',') - dev_sets = [create_dataset([source], - batch_size=FLAGS.dev_batch_size, - train_phase=False, - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, - reverse=FLAGS.reverse_dev, - limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer, - split_dataset=False) for source in dev_sources] - dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] - - if FLAGS.metrics_files: - metrics_sources = FLAGS.metrics_files.split(',') - metrics_sets = [create_dataset([source], - batch_size=FLAGS.dev_batch_size, - train_phase=False, - exception_box=exception_box, - process_ahead=len(Config.available_devices) * FLAGS.dev_batch_size * 2, - reverse=FLAGS.reverse_dev, - limit=FLAGS.limit_dev, - buffering=FLAGS.read_buffer, - split_dataset=False) for source in metrics_sources] - metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] - - # Dropout - dropout_rates = [tfv1.placeholder(tf.float32, name='dropout_{}'.format(i)) for i in range(6)] - dropout_feed_dict = { - dropout_rates[0]: FLAGS.dropout_rate, - dropout_rates[1]: FLAGS.dropout_rate2, - dropout_rates[2]: FLAGS.dropout_rate3, - dropout_rates[3]: FLAGS.dropout_rate4, - dropout_rates[4]: FLAGS.dropout_rate5, - dropout_rates[5]: FLAGS.dropout_rate6, - } - no_dropout_feed_dict = { - rate: 0. 
for rate in dropout_rates - } - - # Building the graph - learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False) - reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) - optimizer = create_optimizer(learning_rate_var) - - # Enable mixed precision training - if FLAGS.automatic_mixed_precision: - log_info('Enabling automatic mixed precision training.') - optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) - - gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) - - # Average tower gradients across GPUs - avg_tower_gradients = average_gradients(gradients) - log_grads_and_vars(avg_tower_gradients) - - # global_step is automagically incremented by the optimizer - global_step = tfv1.train.get_or_create_global_step() - apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step) - - # Summaries - step_summaries_op = tfv1.summary.merge_all('step_summaries') - step_summary_writers = { - 'train': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'train'), max_queue=120), - 'dev': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'dev'), max_queue=120), - 'metrics': tfv1.summary.FileWriter(os.path.join(FLAGS.summary_dir, 'metrics'), max_queue=120), - } - - human_readable_set_names = { - 'train': 'Training', - 'dev': 'Validation', - 'metrics': 'Metrics', - } - - # Checkpointing - checkpoint_saver = tfv1.train.Saver(max_to_keep=FLAGS.max_to_keep) - checkpoint_path = os.path.join(FLAGS.save_checkpoint_dir, 'train') - - best_dev_saver = tfv1.train.Saver(max_to_keep=1) - best_dev_path = os.path.join(FLAGS.save_checkpoint_dir, 'best_dev') - - # Save flags next to checkpoints - if not is_remote_path(FLAGS.save_checkpoint_dir): - os.makedirs(FLAGS.save_checkpoint_dir, exist_ok=True) - flags_file = os.path.join(FLAGS.save_checkpoint_dir, 'flags.txt') - with open_remote(flags_file, 'w') as fout: - fout.write(FLAGS.flags_into_string()) - - with tfv1.Session(config=Config.session_config) as session: - log_debug('Session opened.') - - # Prevent further graph changes - tfv1.get_default_graph().finalize() - - # Load checkpoint or initialize variables - load_or_init_graph_for_training(session) - - def run_set(set_name, epoch, init_op, dataset=None): - is_train = set_name == 'train' - train_op = apply_gradient_op if is_train else [] - feed_dict = dropout_feed_dict if is_train else no_dropout_feed_dict - - total_loss = 0.0 - step_count = 0 - - step_summary_writer = step_summary_writers.get(set_name) - checkpoint_time = time.time() - - if is_train and FLAGS.cache_for_epochs > 0 and FLAGS.feature_cache: - feature_cache_index = FLAGS.feature_cache + '.index' - if epoch % FLAGS.cache_for_epochs == 0 and os.path.isfile(feature_cache_index): - log_info('Invalidating feature cache') - remove_remote(feature_cache_index) # this will let TF also overwrite the related cache data files - - # Setup progress bar - class LossWidget(progressbar.widgets.FormatLabel): - def __init__(self): - progressbar.widgets.FormatLabel.__init__(self, format='Loss: %(mean_loss)f') - - def __call__(self, progress, data, **kwargs): - data['mean_loss'] = total_loss / step_count if step_count else 0.0 - return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) - - prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) - widgets = [' | ', progressbar.widgets.Timer(), - ' | Steps: ', 
progressbar.widgets.Counter(), - ' | ', LossWidget()] - suffix = ' | Dataset: {}'.format(dataset) if dataset else None - pbar = create_progressbar(prefix=prefix, widgets=widgets, suffix=suffix).start() - - # Initialize iterator to the appropriate dataset - session.run(init_op) - - # Batch loop - while True: - try: - _, current_step, batch_loss, problem_files, step_summary = \ - session.run([train_op, global_step, loss, non_finite_files, step_summaries_op], - feed_dict=feed_dict) - exception_box.raise_if_set() - except tf.errors.OutOfRangeError: - exception_box.raise_if_set() - break - - if problem_files.size > 0: - problem_files = [f.decode('utf8') for f in problem_files[..., 0]] - log_error('The following files caused an infinite (or NaN) ' - 'loss: {}'.format(','.join(problem_files))) - - total_loss += batch_loss - step_count += 1 - - pbar.update(step_count) - - step_summary_writer.add_summary(step_summary, current_step) - - if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: - checkpoint_saver.save(session, checkpoint_path, global_step=current_step) - checkpoint_time = time.time() - - pbar.finish() - mean_loss = total_loss / step_count if step_count > 0 else 0.0 - return mean_loss, step_count - - log_info('STARTING Optimization') - train_start_time = datetime.utcnow() - best_dev_loss = float('inf') - dev_losses = [] - epochs_without_improvement = 0 - try: - for epoch in range(FLAGS.epochs): - # Training - log_progress('Training epoch %d...' % epoch) - train_loss, _ = run_set('train', epoch, train_init_op) - log_progress('Finished training epoch %d - loss: %f' % (epoch, train_loss)) - checkpoint_saver.save(session, checkpoint_path, global_step=global_step) - - if FLAGS.dev_files: - # Validation - dev_loss = 0.0 - total_steps = 0 - for source, init_op in zip(dev_sources, dev_init_ops): - log_progress('Validating epoch %d on %s...' 
% (epoch, source)) - set_loss, steps = run_set('dev', epoch, init_op, dataset=source) - dev_loss += set_loss * steps - total_steps += steps - log_progress('Finished validating epoch %d on %s - loss: %f' % (epoch, source, set_loss)) - - dev_loss = dev_loss / total_steps - dev_losses.append(dev_loss) - - # Count epochs without an improvement for early stopping and reduction of learning rate on a plateau - # the improvement has to be greater than FLAGS.es_min_delta - if dev_loss > best_dev_loss - FLAGS.es_min_delta: - epochs_without_improvement += 1 - else: - epochs_without_improvement = 0 - - # Save new best model - if dev_loss < best_dev_loss: - best_dev_loss = dev_loss - save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved new best validating model with loss %f to: %s" % (best_dev_loss, save_path)) - - # Early stopping - if FLAGS.early_stop and epochs_without_improvement == FLAGS.es_epochs: - log_info('Early stop triggered as the loss did not improve the last {} epochs'.format( - epochs_without_improvement)) - break - - # Reduce learning rate on plateau - # If the learning rate was reduced and there is still no improvement - # wait FLAGS.plateau_epochs before the learning rate is reduced again - if ( - FLAGS.reduce_lr_on_plateau - and epochs_without_improvement > 0 - and epochs_without_improvement % FLAGS.plateau_epochs == 0 - ): - # Reload checkpoint that we use the best_dev weights again - reload_best_checkpoint(session) - - # Reduce learning rate - session.run(reduce_learning_rate_op) - current_learning_rate = learning_rate_var.eval() - log_info('Encountered a plateau, reducing learning rate to {}'.format( - current_learning_rate)) - - # Overwrite best checkpoint with new learning rate value - save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') - log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) - - if FLAGS.metrics_files: - # Read only metrics, not affecting best validation loss tracking - for source, init_op in zip(metrics_sources, metrics_init_ops): - log_progress('Metrics for epoch %d on %s...' 
% (epoch, source)) - set_loss, _ = run_set('metrics', epoch, init_op, dataset=source) - log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss)) - - print('-' * 80) - - - except KeyboardInterrupt: - pass - log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) - log_debug('Session closed.') - -def train_with_horovod(): - - import horovod.tensorflow as hvd - - exception_box = ExceptionBox() + if FLAGS.horovod: + import horovod.tensorflow as hvd # Create training and validation datasets + split_dataset = FLAGS.horovod + train_set = create_dataset(FLAGS.train_files.split(','), batch_size=FLAGS.train_batch_size, epochs=FLAGS.epochs, @@ -698,7 +430,7 @@ def train_with_horovod(): reverse=FLAGS.reverse_train, limit=FLAGS.limit_train, buffering=FLAGS.read_buffer, - split_dataset=True) + split_dataset=split_dataset) iterator = tfv1.data.Iterator.from_structure(tfv1.data.get_output_types(train_set), tfv1.data.get_output_shapes(train_set), @@ -717,7 +449,7 @@ def train_with_horovod(): reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, buffering=FLAGS.read_buffer, - split_dataset=True) for source in dev_sources] + split_dataset=split_dataset) for source in dev_sources] dev_init_ops = [iterator.make_initializer(dev_set) for dev_set in dev_sets] if FLAGS.metrics_files: @@ -730,7 +462,7 @@ def train_with_horovod(): reverse=FLAGS.reverse_dev, limit=FLAGS.limit_dev, buffering=FLAGS.read_buffer, - split_dataset=True) for source in metrics_sources] + split_dataset=split_dataset) for source in metrics_sources] metrics_init_ops = [iterator.make_initializer(metrics_set) for metrics_set in metrics_sets] # Dropout @@ -750,25 +482,38 @@ def train_with_horovod(): # Building the graph learning_rate_var = tfv1.get_variable('learning_rate', initializer=FLAGS.learning_rate, trainable=False) reduce_learning_rate_op = learning_rate_var.assign(tf.multiply(learning_rate_var, FLAGS.plateau_reduction)) - - # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size. - optimizer = create_optimizer(learning_rate_var * hvd.size()) - optimizer = hvd.DistributedOptimizer(optimizer) + if FLAGS.horovod: + # Effective batch size in synchronous distributed training is scaled by the number of workers. An increase in learning rate compensates for the increased batch size. 
+ optimizer = create_optimizer(learning_rate_var * hvd.size()) + optimizer = hvd.DistributedOptimizer(optimizer) + else: + optimizer = create_optimizer(learning_rate_var) # Enable mixed precision training if FLAGS.automatic_mixed_precision: log_info('Enabling automatic mixed precision training.') optimizer = tfv1.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) - loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False) - gradients = optimizer.compute_gradients(loss) + if FLAGS.horovod: + loss, non_finite_files = calculate_mean_edit_distance_and_loss(iterator, dropout_rates, reuse=False) + gradients = optimizer.compute_gradients(loss) - tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries']) - log_grads_and_vars(gradients) + tfv1.summary.scalar(name='step_loss', tensor=loss, collections=['step_summaries']) + log_grads_and_vars(gradients) - # global_step is automagically incremented by the optimizer - global_step = tfv1.train.get_or_create_global_step() - apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step) + # global_step is automagically incremented by the optimizer + global_step = tfv1.train.get_or_create_global_step() + apply_gradient_op = optimizer.apply_gradients(gradients, global_step=global_step) + else: + gradients, loss, non_finite_files = get_tower_results(iterator, optimizer, dropout_rates) + + # Average tower gradients across GPUs + avg_tower_gradients = average_gradients(gradients) + log_grads_and_vars(avg_tower_gradients) + + # global_step is automagically incremented by the optimizer + global_step = tfv1.train.get_or_create_global_step() + apply_gradient_op = optimizer.apply_gradients(avg_tower_gradients, global_step=global_step) # Summaries step_summaries_op = tfv1.summary.merge_all('step_summaries') @@ -799,7 +544,8 @@ def train_with_horovod(): with open_remote(flags_file, 'w') as fout: fout.write(FLAGS.flags_into_string()) - bcast = hvd.broadcast_global_variables(0) + if FLAGS.horovod: + bcast = hvd.broadcast_global_variables(0) with tfv1.Session(config=Config.session_config) as session: log_debug('Session opened.') @@ -809,7 +555,8 @@ def train_with_horovod(): # Load checkpoint or initialize variables load_or_init_graph_for_training(session) - bcast.run() + if FLAGS.horovod: + bcast.run() def run_set(set_name, epoch, init_op, dataset=None): is_train = set_name == 'train' @@ -838,7 +585,6 @@ def __call__(self, progress, data, **kwargs): return progressbar.widgets.FormatLabel.__call__(self, progress, data, **kwargs) if Config.is_master_process: - # TODO endl seems not to work with horovod prefix = 'Epoch {} | {:>10}'.format(epoch, human_readable_set_names[set_name]) widgets = [' | ', progressbar.widgets.Timer(), ' | Steps: ', progressbar.widgets.Counter(), @@ -870,13 +616,15 @@ def __call__(self, progress, data, **kwargs): if Config.is_master_process: pbar.update(step_count) + step_summary_writer.add_summary(step_summary, current_step) if is_train and FLAGS.checkpoint_secs > 0 and time.time() - checkpoint_time > FLAGS.checkpoint_secs: checkpoint_saver.save(session, checkpoint_path, global_step=current_step) checkpoint_time = time.time() - pbar.finish() + if Config.is_master_process: + pbar.finish() mean_loss = total_loss / step_count if step_count > 0 else 0.0 return mean_loss, step_count @@ -937,9 +685,9 @@ def __call__(self, progress, data, **kwargs): # If the learning rate was reduced and there is still no improvement # wait FLAGS.plateau_epochs before 
the learning rate is reduced again if ( - FLAGS.reduce_lr_on_plateau - and epochs_without_improvement > 0 - and epochs_without_improvement % FLAGS.plateau_epochs == 0 + FLAGS.reduce_lr_on_plateau + and epochs_without_improvement > 0 + and epochs_without_improvement % FLAGS.plateau_epochs == 0 ): # Reload checkpoint that we use the best_dev weights again reload_best_checkpoint(session) @@ -949,33 +697,30 @@ def __call__(self, progress, data, **kwargs): current_learning_rate = learning_rate_var.eval() if Config.is_master_process: log_info('Encountered a plateau, reducing learning rate to {}'.format( - current_learning_rate)) + current_learning_rate)) # Overwrite best checkpoint with new learning rate value save_path = best_dev_saver.save(session, best_dev_path, global_step=global_step, latest_filename='best_dev_checkpoint') log_info("Saved best validating model with reduced learning rate to: %s" % (save_path)) - if FLAGS.metrics_files: - # Read only metrics, not affecting best validation loss tracking - for source, init_op in zip(metrics_sources, metrics_init_ops): - if Config.is_master_process: - log_progress('Metrics for epoch %d on %s...' % (epoch, source)) - set_loss, _ = run_set('metrics', epoch, init_op, dataset=source) - if Config.is_master_process: - log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss)) + if FLAGS.metrics_files: + # Read only metrics, not affecting best validation loss tracking + for source, init_op in zip(metrics_sources, metrics_init_ops): + if Config.is_master_process: + log_progress('Metrics for epoch %d on %s...' % (epoch, source)) + set_loss, _ = run_set('metrics', epoch, init_op, dataset=source) + if Config.is_master_process: + log_progress('Metrics for epoch %d on %s - loss: %f' % (epoch, source, set_loss)) - if Config.is_master_process: - print('-' * 80) + print('-' * 80) except KeyboardInterrupt: pass if Config.is_master_process: log_info('FINISHED optimization in {}'.format(datetime.utcnow() - train_start_time)) - if Config.is_master_process: - log_debug('Session closed.') - + log_debug('Session closed.') def test(): @@ -1252,10 +997,7 @@ def main(_): tfv1.reset_default_graph() tfv1.set_random_seed(FLAGS.random_seed) - if FLAGS.horovod: - train_with_horovod() - else: - train() + train() if Config.is_master_process: if FLAGS.test_files: From 329bf876069720cf05b4e4700e6d0dde104b6bac Mon Sep 17 00:00:00 2001 From: NanoNabla <43477372+NanoNabla@users.noreply.github.com> Date: Thu, 18 Feb 2021 13:19:52 +0100 Subject: [PATCH 4/4] improve horovod docu --- doc/TRAINING.rst | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/doc/TRAINING.rst b/doc/TRAINING.rst index 863ef477a9..5aaf0c59c1 100644 --- a/doc/TRAINING.rst +++ b/doc/TRAINING.rst @@ -203,9 +203,15 @@ If you have a capable compute architecture, it is possible to distribute the tra Horovod is capable of using MPI and NVIDIA's NCCL for highly optimized inter-process communication. It also offers `Gloo `_ as an easy-to-setup communication backend. -For more information about setup or tuning of Horovod please visit `Horovod's Github `_. +For more information about setup or tuning of Horovod please visit `Horovod's documentation `_. -To train on 4 machines using 4 GPUs each: +Horovod is expected to run on heterogeneous systems (e.g. different number and model type of GPUs per machine). +However, this can cause unpredictable problems and user interaction in training code is needed. 
+Therefore, we only support homogeneous systems, which means the same hardware and the same software configuration (OS, drivers, MPI, NCCL, TensorFlow, ...) on each machine. +The only exception is a different number of GPUs per machine, since this can be controlled via ``horovodrun -H``. + +Detailed documentation on how to run Horovod is provided `here `_. +The short command to train on 4 machines using 4 GPUs each: .. code-block:: bash