example: Add MXNet Gluon training example of MNIST. #22

Merged: 7 commits, Jul 3, 2019
Changes from 3 commits
45 changes: 45 additions & 0 deletions byteps/mxnet/__init__.py
@@ -19,6 +19,7 @@
from __future__ import print_function

import threading
import warnings

from byteps.mxnet.ops import byteps_push_pull, byteps_declare_tensor
from byteps.mxnet.ops import init, shutdown
@@ -123,3 +124,47 @@ def broadcast_parameters(params, root_rank=0):
        tensor.wait_to_read()


class DistributedTrainer(mx.gluon.Trainer):
    """A subclass of MXNet gluon.Trainer.

    There are two differences between DistributedTrainer and Trainer:
    1. DistributedTrainer aggregates gradients with the BytePS push_pull
       API, while Trainer uses the kvstore push/pull APIs;
    2. DistributedTrainer performs allreduce (summation) and averaging,
       while Trainer only performs allreduce (summation).

    Parameters
    ----------
    params : ParameterDict
        The set of parameters to optimize.
    optimizer : str or Optimizer
        The optimizer to use. See
        `help <http://mxnet.io/api/python/optimization/optimization.html#the-mxnet-optimizer-package>`_
        on Optimizer for a list of available optimizers.
    optimizer_params : dict
        Keyword arguments to be passed to the optimizer constructor. For example,
        `{'learning_rate': 0.1}`. All optimizers accept learning_rate, wd (weight decay),
        clip_gradient, and lr_scheduler. See each optimizer's
        constructor for a list of additional supported arguments.
    """

    def __init__(self, params, optimizer, optimizer_params=None):
        if isinstance(optimizer, DistributedOptimizer):
            optimizer = optimizer._optimizer
            warnings.warn("DistributedTrainer does not take DistributedOptimizer "
                          "as its optimizer. We have unwrapped it for you.")

        super(DistributedTrainer, self).__init__(
            params, optimizer, optimizer_params=optimizer_params, kvstore=None)

        # _scale is used to check and set rescale_grad for the optimizer in
        # Trainer.step(). Normalizing it by the BytePS size, which is equivalent
        # to averaging in allreduce, has better performance.
        self._scale /= size()

    def _allreduce_grads(self):
        for i, param in enumerate(self._params):
            if param.grad_req != 'null':
                byteps_push_pull(param.list_grad()[0], is_average=False,
                                 name="parameter_" + str(i), priority=-i)

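As a quick illustration of the API documented above, here is a minimal usage sketch (not part of this diff): the tiny Dense layer and the learning rate are placeholders chosen only to keep the sketch short, and the script is assumed to be started through the BytePS launcher as in the shell script below.

import mxnet as mx
import byteps.mxnet as bps
from mxnet import gluon

bps.init()
ctx = mx.gpu(bps.local_rank())  # or mx.cpu(bps.local_rank()) on CPU-only hosts

net = gluon.nn.Dense(10, in_units=20)  # any gluon.Block works here
net.initialize(ctx=ctx)

params = net.collect_params()
bps.broadcast_parameters(params, root_rank=0)   # sync initial weights from rank 0
trainer = bps.DistributedTrainer(params, 'sgd',
                                 {'learning_rate': 0.01 * bps.size()})
# trainer.step(batch_size) push_pulls each gradient across workers; because
# _scale was divided by bps.size(), the applied update uses their average.
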
17 changes: 17 additions & 0 deletions example/mxnet-gluon/run_mnist_gluon.sh
@@ -0,0 +1,17 @@
#!/bin/bash

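# Single-machine setup: expose GPUs 0 and 1 and register this process as worker 0 of a 1-worker job.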
export NVIDIA_VISIBLE_DEVICES=0,1
export DMLC_WORKER_ID=0
export DMLC_NUM_WORKER=1
export DMLC_ROLE=worker

# the following value does not matter for non-distributed jobs
export DMLC_NUM_SERVER=1
export DMLC_PS_ROOT_URI=127.0.0.1
export DMLC_PS_ROOT_PORT=9000

path="`dirname $0`"
echo $path

python $path/../../launcher/launch.py \
python $path/train_mnist_byteps.py
169 changes: 169 additions & 0 deletions example/mxnet-gluon/train_mnist_byteps.py
@@ -0,0 +1,169 @@
# Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved.
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is modified from `horovod/examples/mxnet_mnist.py`, using gluon style MNIST dataset and data_loader."""
import time

import argparse
import logging

import mxnet as mx
import byteps.mxnet as bps
from mxnet import autograd, gluon, nd
from mxnet.gluon.data.vision import MNIST


# Uncomment for higher download speed for users in China (requires importing os)
# os.environ['MXNET_GLUON_REPO'] = 'https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/'

# Training settings
parser = argparse.ArgumentParser(description='MXNet MNIST Example')

parser.add_argument('--batch-size', type=int, default=64,
                    help='training batch size (default: 64)')
parser.add_argument('--dtype', type=str, default='float32',
                    help='training data type (default: float32)')
parser.add_argument('--epochs', type=int, default=5,
                    help='number of training epochs (default: 5)')
parser.add_argument('--j', type=int, default=2,
                    help='number of CPU processes for the data loader (default: 2)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable training on GPU (default: False)')
args = parser.parse_args()

if not args.no_cuda:
    # Disable CUDA if there are no GPUs.
    if mx.context.num_gpus() == 0:
        args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)


def dummy_transform(data, label):
    im = data.astype(args.dtype, copy=False) / 255 - 0.5
    im = nd.transpose(im, (2, 0, 1))
    return im, label


# Function to get mnist iterator
def get_mnist_iterator():
    train_set = MNIST(train=True, transform=dummy_transform)
    train_iter = gluon.data.DataLoader(train_set, args.batch_size, True,
                                       num_workers=args.j, last_batch='discard')
    val_set = MNIST(train=False, transform=dummy_transform)
    val_iter = gluon.data.DataLoader(val_set, args.batch_size, False, num_workers=0)

    return train_iter, val_iter, len(train_set)


# Function to define neural network
def conv_nets():
    net = gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(512, activation="relu"))
        net.add(gluon.nn.Dense(10))
    return net


# Function to evaluate accuracy for a model
def evaluate(model, data_iter, context):
    metric = mx.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])

    return metric.get()


# Initialize BytePS
bps.init()

# BytePS: pin context to local rank
context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(bps.local_rank())
num_workers = bps.size()

# Load training and validation data
train_data, val_data, train_size = get_mnist_iterator()

# Build model
model = conv_nets()
model.cast(args.dtype)

# Initialize parameters
model.initialize(mx.init.MSRAPrelu(), ctx=context)
# if bps.rank() == 0:
model.summary(nd.ones((1, 1, 28, 28), ctx=context, dtype=args.dtype))
model.hybridize()

# BytePS: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    bps.broadcast_parameters(params, root_rank=0)

# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
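# The learning rate is scaled by the number of workers because the effective
# global batch size grows with bps.size() (a common linear-scaling heuristic).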
optimizer_params = {'momentum': args.momentum, 'learning_rate': args.lr * num_workers}
trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    metric.reset()
    for i, batch in enumerate(train_data):
        data = batch[0].as_in_context(context)
        label = batch[1].as_in_context(context)

        with autograd.record():
            output = model(data)
            loss = loss_fn(output, label)

        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if i % 100 == 0:
            name, acc = metric.get()
            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                         (epoch, i, name, acc))

    if bps.rank() == 0:
        elapsed = time.time() - tic
        speed = train_size * num_workers / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                     epoch, speed, elapsed)

    # Evaluate model accuracy
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if bps.rank() == 0:
        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
                     train_acc, name, val_acc)

    if bps.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc