diff --git a/example/distributed_training-horovod/README.md b/example/distributed_training-horovod/README.md new file mode 100644 index 000000000000..c4776044a385 --- /dev/null +++ b/example/distributed_training-horovod/README.md @@ -0,0 +1,201 @@ + + + + + + + + + + + + + + + + + +# Distributed Training using MXNet with Horovod +[Horovod](https://github.com/horovod/horovod) is a distributed training framework that demonstrates +excellent scaling efficiency for dense models running on a large number of nodes. It currently +supports mainstream deep learning frameworks such as MXNet, TensorFlow, Keras, and PyTorch. +It is created at Uber and currently hosted by the [Linux Foundation Deep Learning](https://lfdl.io)(LF DL). + +MXNet is supported in Horovod 0.16.0 [release](https://eng.uber.com/horovod-pyspark-apache-mxnet-support/). + +## What's New? +Compared with the standard distributed training script in MXNet which uses parameter server to +distribute and aggregate parameters, Horovod uses ring allreduce and/or tree-based allreduce algorithm +to communicate parameters between workers. There is no dedicated server and the communication data size +between workers does not depend on the number of workers. Therefore, it scales well in the case where +there are a large number of workers and network bandwidth is the bottleneck. + +# Install +## Install MXNet +```bash +$ pip install mxnet +``` +**Note**: There is a [known issue](https://github.com/horovod/horovod/issues/884) when running Horovod with MXNet on a Linux system with GCC version 5.X and above. We recommend users to build MXNet from source following this [guide](https://mxnet.incubator.apache.org/install/build_from_source.html) as a workaround for now. Also mxnet-mkl package in 1.4.0 release does not support Horovod. + +## Install Horovod +```bash +$ pip install horovod +``` + +This basic installation is good for laptops and for getting to know Horovod. +If you're installing Horovod on a server with GPUs, read the [Horovod on GPU](https://github.com/horovod/horovod/blob/master/docs/gpus.md) page. +If you want to use Docker, read the [Horovod in Docker](https://github.com/horovod/horovod/blob/master/docs/docker.md) page. + +## Install MPI +MPI is required to run distributed training with Horovod. Install [Open MPI](https://www.open-mpi.org/) or another MPI implementation. +Steps to install Open MPI are listed [here](https://www.open-mpi.org/faq/?category=building#easy-build). + +**Note**: Open MPI 3.1.3 has an issue that may cause hangs. It is recommended +to downgrade to Open MPI 3.1.2 or upgrade to Open MPI 4.0.0. + +# Usage + +To run MXNet with Horovod, make the following additions to your training script: + +1. Run `hvd.init()`. + +2. Pin the context to a processor using `hvd.local_rank()`. + Typically, each Horovod worker is associated with one process. The local rank is a unique ID specifically + for all processes running Horovod job on the same node. + +3. Scale the learning rate by number of workers. Effective batch size in synchronous distributed training is scaled by + the number of workers. An increase in learning rate compensates for the increased batch size. + +4. Wrap optimizer in `hvd.DistributedOptimizer`. The distributed optimizer delegates gradient computation + to the original optimizer, averages gradients using *allreduce* or *allgather*, and then applies those averaged + gradients. + +5. Add `hvd.broadcast_parameters` to broadcast initial variable states from rank 0 to all other processes. + This is necessary to ensure consistent initialization of all workers when training is started with random weights or + restored from a checkpoint. + +# Example + +Here we provide the building blocks to train a model using MXNet with Horovod. +The full examples are in [MNIST](gluon_mnist.py) and [ImageNet](resnet50_imagenet.py). + +## Gluon API +```python +from mxnet import autograd, gluon +import mxnet as mx +import horovod.mxnet as hvd + +# Initialize Horovod +hvd.init() + +# Set context to current process +context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank()) + +num_workers = hvd.size() + +# Build model +model = ... +model.hybridize() + +# Define hyper parameters +optimizer_params = ... + +# Add Horovod Distributed Optimizer +opt = mx.optimizer.create('sgd', **optimizer_params) +opt = hvd.DistributedOptimizer(opt) + +# Initialize parameters +model.initialize(initializer, ctx=context) + +# Fetch and broadcast parameters +params = model.collect_params() +if params is not None: + hvd.broadcast_parameters(params, root_rank=0) + +# Create trainer and loss function +trainer = gluon.Trainer(params, opt, kvstore=None) +loss_fn = ... + +# Train model +for epoch in range(num_epoch): + train_data.reset() + for nbatch, batch in enumerate(train_data, start=1): + data = batch.data[0].as_in_context(context) + label = batch.label[0].as_in_context(context) + with autograd.record(): + output = model(data.astype(dtype, copy=False)) + loss = loss_fn(output, label) + loss.backward() + trainer.step(batch_size) +``` + +## Module API +```python +import mxnet as mx +import horovod.mxnet as hvd + +# Initialize Horovod +hvd.init() + +# Set context to current process +context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank()) +num_workers = hvd.size() + +# Build model +model = ... + +# Define hyper parameters +optimizer_params = ... + +# Add Horovod Distributed Optimizer +opt = mx.optimizer.create('sgd', **optimizer_params) +opt = hvd.DistributedOptimizer(opt) + +# Initialize parameters +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) +model.bind(data_shapes=train_data.provide_data, + label_shapes=train_data.provide_label) +model.init_params(initializer) + +# Fetch and broadcast parameters +(arg_params, aux_params) = model.get_params() +if arg_params: + hvd.broadcast_parameters(arg_params, root_rank=0) +if aux_params: + hvd.broadcast_parameters(aux_params, root_rank=0) +model.set_params(arg_params=arg_params, aux_params=aux_params) + +# Train model +model.fit(train_data, + kvstore=None, + optimizer=opt, + num_epoch=num_epoch) +``` + + +# Running Horovod + +The example commands below show how to run distributed training. See the +[Running Horovod](https://github.com/horovod/horovod/blob/master/docs/running.md) +page for more instructions, including RoCE/InfiniBand tweaks and tips for dealing with hangs. + +1. To run on a machine with 4 CPUs: + +```bash +$ mpirun -np 4 \ + -H localhost:4 \ + -bind-to none -map-by slot \ + python train.py +``` + +2. To run on 2 machines with 4 GPUs each: + +```bash +$ mpirun -np 8 \ + -H server1:4,server2:4 \ + -bind-to none -map-by slot \ + -x NCCL_DEBUG=INFO \ + -mca pml ob1 -mca btl ^openib \ + python train.py +``` \ No newline at end of file diff --git a/example/distributed_training-horovod/gluon_mnist.py b/example/distributed_training-horovod/gluon_mnist.py new file mode 100644 index 000000000000..7e4be58cc2ef --- /dev/null +++ b/example/distributed_training-horovod/gluon_mnist.py @@ -0,0 +1,186 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import os +import zipfile +import time + +import mxnet as mx +import horovod.mxnet as hvd +from mxnet import autograd, gluon, nd +from mxnet.test_utils import download + +# Training settings +parser = argparse.ArgumentParser(description='MXNet MNIST Example') + +parser.add_argument('--batch-size', type=int, default=64, + help='training batch size (default: 64)') +parser.add_argument('--dtype', type=str, default='float32', + help='training data type (default: float32)') +parser.add_argument('--epochs', type=int, default=5, + help='number of training epochs (default: 5)') +parser.add_argument('--lr', type=float, default=0.01, + help='learning rate (default: 0.01)') +parser.add_argument('--momentum', type=float, default=0.9, + help='SGD momentum (default: 0.9)') +parser.add_argument('--use-gpu', action='store_true', default=False, + help='run training on GPU (default: False)') +args = parser.parse_args() + +logging.basicConfig(level=logging.INFO) +logging.info(args) + + +# Function to get mnist iterator given a rank +def get_mnist_iterator(rank): + data_dir = "data-%d" % rank + if not os.path.isdir(data_dir): + os.makedirs(data_dir) + zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip', + dirname=data_dir) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(data_dir) + + input_shape = (1, 28, 28) + batch_size = args.batch_size + + train_iter = mx.io.MNISTIter( + image="%s/train-images-idx3-ubyte" % data_dir, + label="%s/train-labels-idx1-ubyte" % data_dir, + input_shape=input_shape, + batch_size=batch_size, + shuffle=True, + flat=False, + num_parts=hvd.size(), + part_index=hvd.rank() + ) + + val_iter = mx.io.MNISTIter( + image="%s/t10k-images-idx3-ubyte" % data_dir, + label="%s/t10k-labels-idx1-ubyte" % data_dir, + input_shape=input_shape, + batch_size=batch_size, + flat=False, + ) + + return train_iter, val_iter + + +# Function to define neural network +def conv_nets(): + net = gluon.nn.HybridSequential() + with net.name_scope(): + net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu')) + net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2)) + net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu')) + net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2)) + net.add(gluon.nn.Flatten()) + net.add(gluon.nn.Dense(512, activation="relu")) + net.add(gluon.nn.Dense(10)) + return net + + +# Function to evaluate accuracy for a model +def evaluate(model, data_iter, context): + data_iter.reset() + metric = mx.metric.Accuracy() + for _, batch in enumerate(data_iter): + data = batch.data[0].as_in_context(context) + label = batch.label[0].as_in_context(context) + output = model(data.astype(args.dtype, copy=False)) + metric.update([label], [output]) + + return metric.get() + + +# Initialize Horovod +hvd.init() + +# Horovod: pin context to local rank +context = mx.gpu(hvd.local_rank()) if args.use_gpu else mx.cpu(hvd.local_rank()) +num_workers = hvd.size() + +# Load training and validation data +train_data, val_data = get_mnist_iterator(hvd.rank()) + +# Build model +model = conv_nets() +model.cast(args.dtype) +model.hybridize() + +# Define hyper parameters +optimizer_params = {'momentum': args.momentum, + 'learning_rate': args.lr * hvd.size(), + 'rescale_grad': 1.0 / args.batch_size} + +# Add Horovod Distributed Optimizer +opt = mx.optimizer.create('sgd', **optimizer_params) +opt = hvd.DistributedOptimizer(opt) + +# Initialize parameters +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) +model.initialize(initializer, ctx=context) + +# Fetch and broadcast parameters +params = model.collect_params() +if params is not None: + hvd.broadcast_parameters(params, root_rank=0) + +# Create trainer, loss function and train metric +trainer = gluon.Trainer(params, opt, kvstore=None) +loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() +metric = mx.metric.Accuracy() + +# Train model +for epoch in range(args.epochs): + tic = time.time() + train_data.reset() + metric.reset() + for nbatch, batch in enumerate(train_data, start=1): + data = batch.data[0].as_in_context(context) + label = batch.label[0].as_in_context(context) + with autograd.record(): + output = model(data.astype(args.dtype, copy=False)) + loss = loss_fn(output, label) + loss.backward() + trainer.step(args.batch_size) + metric.update([label], [output]) + + if nbatch % 100 == 0: + name, acc = metric.get() + logging.info('[Epoch %d Batch %d] Training: %s=%f' % + (epoch, nbatch, name, acc)) + + if hvd.rank() == 0: + elapsed = time.time() - tic + speed = nbatch * args.batch_size * hvd.size() / elapsed + logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f', + epoch, speed, elapsed) + + # Evaluate model accuracy + _, train_acc = metric.get() + name, val_acc = evaluate(model, val_data, context) + if hvd.rank() == 0: + logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name, + train_acc, name, val_acc) + + if hvd.rank() == 0 and epoch == args.epochs - 1: + assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected\ + (0.96)" % val_acc diff --git a/example/distributed_training-horovod/module_mnist.py b/example/distributed_training-horovod/module_mnist.py new file mode 100644 index 000000000000..5c02aaed966c --- /dev/null +++ b/example/distributed_training-horovod/module_mnist.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import os +import zipfile + +import horovod.mxnet as hvd +import mxnet as mx +from mxnet.test_utils import download + +# Training settings +parser = argparse.ArgumentParser(description='MXNet MNIST Example') +parser.add_argument('--batch-size', type=int, default=64, + help='training batch size (default: 64)') +parser.add_argument('--dtype', type=str, default='float32', + help='training data type (default: float32)') +parser.add_argument('--epochs', type=int, default=5, + help='number of training epochs (default: 5)') +parser.add_argument('--lr', type=float, default=0.05, + help='learning rate (default: 0.05)') +parser.add_argument('--momentum', type=float, default=0.5, + help='SGD momentum (default: 0.5)') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training (default: False)') +args = parser.parse_args() + +if not args.no_cuda: + # Disable CUDA if there are no GPUs. + if not mx.test_utils.list_gpus(): + args.no_cuda = True + +logging.basicConfig(level=logging.INFO) +logging.info(args) + + +# Function to get mnist iterator given a rank +def get_mnist_iterator(rank): + data_dir = "data-%d" % rank + if not os.path.isdir(data_dir): + os.makedirs(data_dir) + zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip', + dirname=data_dir) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(data_dir) + + input_shape = (1, 28, 28) + batch_size = args.batch_size + + train_iter = mx.io.MNISTIter( + image="%s/train-images-idx3-ubyte" % data_dir, + label="%s/train-labels-idx1-ubyte" % data_dir, + input_shape=input_shape, + batch_size=batch_size, + shuffle=True, + flat=False, + num_parts=hvd.size(), + part_index=hvd.rank() + ) + + val_iter = mx.io.MNISTIter( + image="%s/t10k-images-idx3-ubyte" % data_dir, + label="%s/t10k-labels-idx1-ubyte" % data_dir, + input_shape=input_shape, + batch_size=batch_size, + flat=False, + num_parts=hvd.size(), + part_index=hvd.rank() + ) + + return train_iter, val_iter + +# Step 1: initialize Horovod +hvd.init() + +# Horovod: pin context to process +context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank()) + +# Step 2: load data +train_iter, val_iter = get_mnist_iterator(hvd.rank()) + + +# Step 3: define network +def conv_net(): + # placeholder for data + data = mx.sym.var('data') + # first conv layer + conv1 = mx.sym.Convolution(data=data, kernel=(5, 5), num_filter=10) + relu1 = mx.sym.Activation(data=conv1, act_type='relu') + pool1 = mx.sym.Pooling(data=relu1, pool_type='max', kernel=(2, 2), + stride=(2, 2)) + # second conv layer + conv2 = mx.sym.Convolution(data=pool1, kernel=(5, 5), num_filter=20) + relu2 = mx.sym.Activation(data=conv2, act_type='relu') + pool2 = mx.sym.Pooling(data=relu2, pool_type='max', kernel=(2, 2), + stride=(2, 2)) + # first fully connected layer + flatten = mx.sym.flatten(data=pool2) + fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=50) + relu3 = mx.sym.Activation(data=fc1, act_type='relu') + # second fully connected layer + fc2 = mx.sym.FullyConnected(data=relu3, num_hidden=10) + # softmax loss + loss = mx.sym.SoftmaxOutput(data=fc2, name='softmax') + return loss + + +# Step 4: fit the model +net = conv_net() +model = mx.mod.Module(symbol=net, context=context) +optimizer_params = {'learning_rate': args.lr * hvd.size(), + 'rescale_grad': 1.0 / args.batch_size} +opt = mx.optimizer.create('sgd', **optimizer_params) + +# Horovod: wrap optimizer with DistributedOptimizer +opt = hvd.DistributedOptimizer(opt) + +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) +model.bind(data_shapes=train_iter.provide_data, + label_shapes=train_iter.provide_label) +model.init_params(initializer) + +# Horovod: fetch and broadcast parameters +(arg_params, aux_params) = model.get_params() +if arg_params is not None: + hvd.broadcast_parameters(arg_params, root_rank=0) +if aux_params is not None: + hvd.broadcast_parameters(aux_params, root_rank=0) +model.set_params(arg_params=arg_params, aux_params=aux_params) + +model.fit(train_iter, # train data + kvstore=None, # no kvstore + eval_data=val_iter, # validation data + optimizer=opt, # use SGD to train + eval_metric='acc', # report accuracy during training + batch_end_callback=mx.callback.Speedometer(args.batch_size), + num_epoch=args.epochs) # train for at most 10 dataset passes + +# Step 5: evaluate model accuracy +acc = mx.metric.Accuracy() +model.score(val_iter, acc) + +if hvd.rank() == 0: + print(acc) + assert acc.get()[1] > 0.96, "Achieved accuracy (%f) is lower than \ + expected (0.96)" % acc.get()[1] diff --git a/example/distributed_training-horovod/resnet50_imagenet.py b/example/distributed_training-horovod/resnet50_imagenet.py new file mode 100644 index 000000000000..9b993403a9f0 --- /dev/null +++ b/example/distributed_training-horovod/resnet50_imagenet.py @@ -0,0 +1,453 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import math +import os +import time + +from gluoncv.model_zoo import get_model +import horovod.mxnet as hvd +import mxnet as mx +import numpy as np +from mxnet import autograd, gluon, lr_scheduler +from mxnet.io import DataBatch, DataIter + + +# Training settings +parser = argparse.ArgumentParser(description='MXNet ImageNet Example', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--use-rec', action='store_true', default=False, + help='use image record iter for data input (default: False)') +parser.add_argument('--data-nthreads', type=int, default=2, + help='number of threads for data decoding (default: 2)') +parser.add_argument('--rec-train', type=str, default='', + help='the training data') +parser.add_argument('--rec-train-idx', type=str, default='', + help='the index of training data') +parser.add_argument('--rec-val', type=str, default='', + help='the validation data') +parser.add_argument('--rec-val-idx', type=str, default='', + help='the index of validation data') +parser.add_argument('--batch-size', type=int, default=128, + help='training batch size per device (default: 128)') +parser.add_argument('--dtype', type=str, default='float32', + help='data type for training (default: float32)') +parser.add_argument('--num-epochs', type=int, default=90, + help='number of training epochs (default: 90)') +parser.add_argument('--lr', type=float, default=0.05, + help='learning rate for a single GPU (default: 0.05)') +parser.add_argument('--momentum', type=float, default=0.9, + help='momentum value for optimizer (default: 0.9)') +parser.add_argument('--wd', type=float, default=0.0001, + help='weight decay rate (default: 0.0001)') +parser.add_argument('--lr-mode', type=str, default='poly', + help='learning rate scheduler mode. Options are step, \ + poly and cosine (default: poly)') +parser.add_argument('--lr-decay', type=float, default=0.1, + help='decay rate of learning rate (default: 0.1)') +parser.add_argument('--lr-decay-epoch', type=str, default='40,60', + help='epoches at which learning rate decays (default: 40,60)') +parser.add_argument('--warmup-lr', type=float, default=0.0, + help='starting warmup learning rate (default: 0.0)') +parser.add_argument('--warmup-epochs', type=int, default=10, + help='number of warmup epochs (default: 10)') +parser.add_argument('--last-gamma', action='store_true', default=False, + help='whether to init gamma of the last BN layer in \ + each bottleneck to 0 (default: False)') +parser.add_argument('--model', type=str, default='resnet50_v1', + help='type of model to use. see vision_model for options.') +parser.add_argument('--mode', type=str, default='module', + help='mode in which to train the model. options are \ + module, gluon (default: module)') +parser.add_argument('--use-pretrained', action='store_true', default=False, + help='load pretrained model weights (default: False)') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training (default: False)') +parser.add_argument('--eval-epoch', action='store_true', default=False, + help='evaluate validation accuracy after each epoch \ + when training in module mode (default: False)') +parser.add_argument('--eval-frequency', type=int, default=0, + help='frequency of evaluating validation accuracy \ + when training with gluon mode (default: 0)') +parser.add_argument('--log-interval', type=int, default=0, + help='number of batches to wait before logging (default: 0)') +parser.add_argument('--save-frequency', type=int, default=0, + help='frequency of model saving (default: 0)') + + +args = parser.parse_args() + +logging.basicConfig(level=logging.INFO) +logging.info(args) + +# Horovod: initialize Horovod +hvd.init() +num_workers = hvd.size() +rank = hvd.rank() +local_rank = hvd.local_rank() + +num_classes = 1000 +num_training_samples = 1281167 +batch_size = args.batch_size +epoch_size = \ + int(math.ceil(int(num_training_samples // num_workers) / batch_size)) + +if args.lr_mode == 'step': + lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')] + steps = [epoch_size * x for x in lr_decay_epoch] + lr_sched = lr_scheduler.MultiFactorScheduler( + step=steps, + factor=args.lr_decay, + base_lr=(args.lr * num_workers), + warmup_steps=(args.warmup_epochs * epoch_size), + warmup_begin_lr=args.warmup_lr + ) +elif args.lr_mode == 'poly': + lr_sched = lr_scheduler.PolyScheduler( + args.num_epochs * epoch_size, + base_lr=(args.lr * num_workers), + pwr=2, + warmup_steps=(args.warmup_epochs * epoch_size), + warmup_begin_lr=args.warmup_lr + ) +elif args.lr_mode == 'cosine': + lr_sched = lr_scheduler.CosineScheduler( + args.num_epochs * epoch_size, + base_lr=(args.lr * num_workers), + warmup_steps=(args.warmup_epochs * epoch_size), + warmup_begin_lr=args.warmup_lr + ) +else: + raise ValueError('Invalid lr mode') + +# Function for reading data from record file +# For more details about data loading in MXNet, please refer to +# https://mxnet.incubator.apache.org/tutorials/basic/data.html?highlight=imagerecorditer +def get_data_rec(rec_train, rec_train_idx, rec_val, rec_val_idx, batch_size, + data_nthreads): + rec_train = os.path.expanduser(rec_train) + rec_train_idx = os.path.expanduser(rec_train_idx) + rec_val = os.path.expanduser(rec_val) + rec_val_idx = os.path.expanduser(rec_val_idx) + jitter_param = 0.4 + lighting_param = 0.1 + mean_rgb = [123.68, 116.779, 103.939] + + def batch_fn(batch, ctx): + data = batch.data[0].as_in_context(ctx) + label = batch.label[0].as_in_context(ctx) + return data, label + + train_data = mx.io.ImageRecordIter( + path_imgrec=rec_train, + path_imgidx=rec_train_idx, + preprocess_threads=data_nthreads, + shuffle=True, + batch_size=batch_size, + label_width=1, + data_shape=(3, 224, 224), + mean_r=mean_rgb[0], + mean_g=mean_rgb[1], + mean_b=mean_rgb[2], + rand_mirror=True, + rand_crop=False, + random_resized_crop=True, + max_aspect_ratio=4. / 3., + min_aspect_ratio=3. / 4., + max_random_area=1, + min_random_area=0.08, + verbose=False, + brightness=jitter_param, + saturation=jitter_param, + contrast=jitter_param, + pca_noise=lighting_param, + num_parts=num_workers, + part_index=rank, + device_id=local_rank + ) + # Kept each node to use full val data to make it easy to monitor results + val_data = mx.io.ImageRecordIter( + path_imgrec=rec_val, + path_imgidx=rec_val_idx, + preprocess_threads=data_nthreads, + shuffle=False, + batch_size=batch_size, + resize=256, + label_width=1, + rand_crop=False, + rand_mirror=False, + data_shape=(3, 224, 224), + mean_r=mean_rgb[0], + mean_g=mean_rgb[1], + mean_b=mean_rgb[2], + device_id=local_rank + ) + + return train_data, val_data, batch_fn + +# Create data iterator for synthetic data +class SyntheticDataIter(DataIter): + def __init__(self, num_classes, data_shape, max_iter, dtype, ctx): + self.batch_size = data_shape[0] + self.cur_iter = 0 + self.max_iter = max_iter + self.dtype = dtype + label = np.random.randint(0, num_classes, [self.batch_size, ]) + data = np.random.uniform(-1, 1, data_shape) + self.data = mx.nd.array(data, dtype=self.dtype, + ctx=ctx) + self.label = mx.nd.array(label, dtype=self.dtype, + ctx=ctx) + + def __iter__(self): + return self + + @property + def provide_data(self): + return [mx.io.DataDesc('data', self.data.shape, self.dtype)] + + @property + def provide_label(self): + return [mx.io.DataDesc('softmax_label', + (self.batch_size,), self.dtype)] + + def next(self): + self.cur_iter += 1 + if self.cur_iter <= self.max_iter: + return DataBatch(data=(self.data,), + label=(self.label,), + pad=0, + index=None, + provide_data=self.provide_data, + provide_label=self.provide_label) + else: + raise StopIteration + + def __next__(self): + return self.next() + + def reset(self): + self.cur_iter = 0 + +# Horovod: pin GPU to local rank +context = mx.cpu(local_rank) if args.no_cuda else mx.gpu(local_rank) + +if args.use_rec: + # Fetch training and validation data if present + train_data, val_data, batch_fn = get_data_rec(args.rec_train, + args.rec_train_idx, + args.rec_val, + args.rec_val_idx, + batch_size, + args.data_nthreads) +else: + # Otherwise use synthetic data + image_shape = (3, 224, 224) + data_shape = (batch_size,) + image_shape + train_data = SyntheticDataIter(num_classes, data_shape, epoch_size, + np.float32, context) + val_data = None + + +# Get model from GluonCV model zoo +# https://gluon-cv.mxnet.io/model_zoo/index.html +kwargs = {'ctx': context, + 'pretrained': args.use_pretrained, + 'classes': num_classes} +if args.last_gamma: + kwargs['last_gamma'] = True +net = get_model(args.model, **kwargs) +net.cast(args.dtype) + +# Create initializer +initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", + magnitude=2) + +# Create optimizer +optimizer_params = {'wd': args.wd, + 'momentum': args.momentum, + 'rescale_grad': 1.0 / batch_size, + 'lr_scheduler': lr_sched} +if args.dtype == 'float16': + optimizer_params['multi_precision'] = True +opt = mx.optimizer.create('sgd', **optimizer_params) + +# Horovod: wrap optimizer with DistributedOptimizer +opt = hvd.DistributedOptimizer(opt) + + +def train_gluon(): + def evaluate(epoch): + if not args.use_rec: + return + + val_data.reset() + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + for _, batch in enumerate(val_data): + data, label = batch_fn(batch, context) + output = net(data.astype(args.dtype, copy=False)) + acc_top1.update([label], [output]) + acc_top5.update([label], [output]) + + top1_name, top1_acc = acc_top1.get() + top5_name, top5_acc = acc_top5.get() + logging.info('Epoch[%d] Rank[%d]\tValidation-%s=%f\tValidation-%s=%f', + epoch, rank, top1_name, top1_acc, top5_name, top5_acc) + + # Hybridize and initialize model + net.hybridize() + net.initialize(initializer, ctx=context) + + # Horovod: fetch and broadcast parameters + params = net.collect_params() + if params is not None: + hvd.broadcast_parameters(params, root_rank=0) + + # Create trainer, loss function and train metric + trainer = gluon.Trainer(params, opt, kvstore=None) + loss_fn = gluon.loss.SoftmaxCrossEntropyLoss() + metric = mx.metric.Accuracy() + + # Train model + for epoch in range(args.num_epochs): + tic = time.time() + if args.use_rec: + train_data.reset() + metric.reset() + + btic = time.time() + for nbatch, batch in enumerate(train_data, start=1): + data, label = batch_fn(batch, context) + with autograd.record(): + output = net(data.astype(args.dtype, copy=False)) + loss = loss_fn(output, label) + loss.backward() + trainer.step(batch_size) + + metric.update([label], [output]) + if args.log_interval and nbatch % args.log_interval == 0: + name, acc = metric.get() + logging.info('Epoch[%d] Rank[%d] Batch[%d]\t%s=%f\tlr=%f', + epoch, rank, nbatch, name, acc, trainer.learning_rate) + if rank == 0: + batch_speed = num_workers * batch_size * args.log_interval / (time.time() - btic) + logging.info('Epoch[%d] Batch[%d]\tSpeed: %.2f samples/sec', + epoch, nbatch, batch_speed) + btic = time.time() + + # Report metrics + elapsed = time.time() - tic + _, acc = metric.get() + logging.info('Epoch[%d] Rank[%d] Batch[%d]\tTime cost=%.2f\tTrain-accuracy=%f', + epoch, rank, nbatch, elapsed, acc) + if rank == 0: + epoch_speed = num_workers * batch_size * nbatch / elapsed + logging.info('Epoch[%d]\tSpeed: %.2f samples/sec', epoch, epoch_speed) + + # Evaluate performance + if args.eval_frequency and (epoch + 1) % args.eval_frequency == 0: + evaluate(epoch) + + # Save model + if args.save_frequency and (epoch + 1) % args.save_frequency == 0: + net.export('%s-%d' % (args.model, rank), epoch=epoch) + + # Evaluate performance at the end of training + evaluate(epoch) + + +def train_module(): + # Create input symbol + data = mx.sym.var('data') + if args.dtype == 'float16': + data = mx.sym.Cast(data=data, dtype=np.float16) + net.cast(np.float16) + + # Create output symbol + out = net(data) + if args.dtype == 'float16': + out = mx.sym.Cast(data=out, dtype=np.float32) + softmax = mx.sym.SoftmaxOutput(out, name='softmax') + + # Create model + mod = mx.mod.Module(softmax, context=context) + + # Initialize parameters + if args.use_pretrained: + arg_params = {} + for x in net.collect_params().values(): + x.reset_ctx(mx.cpu()) + arg_params[x.name] = x.data() + else: + arg_params = None + aux_params = None + mod.bind(data_shapes=train_data.provide_data, + label_shapes=train_data.provide_label) + mod.init_params(initializer, arg_params=arg_params, aux_params=aux_params) + + # Horovod: fetch and broadcast parameters + (arg_params, aux_params) = mod.get_params() + if arg_params is not None: + hvd.broadcast_parameters(arg_params, root_rank=0) + if aux_params is not None: + hvd.broadcast_parameters(aux_params, root_rank=0) + mod.set_params(arg_params=arg_params, aux_params=aux_params) + + # Setup validation data and callback during training + eval_data = None + if args.eval_epoch: + eval_data = val_data + batch_callback = None + if args.log_interval > 0 and rank == 0: + batch_callback = mx.callback.Speedometer(batch_size * num_workers, + args.log_interval) + + epoch_callback = None + if args.save_frequency > 0: + epoch_callback = mx.callback.do_checkpoint( + '%s-%d' % (args.model, rank), + period=args.save_frequency) + + # Train model + mod.fit(train_data, + eval_data=eval_data, + num_epoch=args.num_epochs, + kvstore=None, + batch_end_callback=batch_callback, + epoch_end_callback=epoch_callback, + optimizer=opt) + + # Evaluate performance if not using synthetic data + if args.use_rec: + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + res = mod.score(val_data, [acc_top1, acc_top5]) + for name, val in res: + logging.info('Epoch[%d] Rank[%d] Validation-%s=%f', + args.num_epochs - 1, rank, name, val) + + +if __name__ == '__main__': + if args.mode == 'module': + train_module() + elif args.mode == 'gluon': + train_gluon() + else: + raise ValueError('Invalid training mode.')