diff --git a/byteps/mxnet/__init__.py b/byteps/mxnet/__init__.py
index 4f053aea7..b269248a7 100644
--- a/byteps/mxnet/__init__.py
+++ b/byteps/mxnet/__init__.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 import threading
+import warnings
 from byteps.mxnet.ops import byteps_push_pull, byteps_declare_tensor
 from byteps.mxnet.ops import init, shutdown
@@ -123,3 +124,47 @@ def broadcast_parameters(params, root_rank=0):
             tensor.wait_to_read()
+
+
+class DistributedTrainer(mx.gluon.Trainer):
+    """A subclass of MXNet gluon.Trainer.
+
+    There are two differences between DistributedTrainer and Trainer:
+    1. DistributedTrainer calculates gradients using the BytePS push_pull
+       API, while Trainer does it using kvstore push/pull APIs;
+    2. DistributedTrainer performs push_pull (summation) and averaging,
+       while Trainer only performs push_pull (summation).
+
+    Parameters
+    ----------
+    params : ParameterDict
+        The set of parameters to optimize.
+    optimizer : str or Optimizer
+        The optimizer to use. See help on Optimizer for a list of available
+        optimizers.
+    optimizer_params : dict
+        Keyword arguments to be passed to the optimizer constructor. For
+        example, `{'learning_rate': 0.1}`. All optimizers accept
+        learning_rate, wd (weight decay), clip_gradient, and lr_scheduler.
+        See each optimizer's constructor for a list of additional supported
+        arguments.
+    """
+
+    def __init__(self, params, optimizer, optimizer_params=None):
+        if isinstance(optimizer, DistributedOptimizer):
+            optimizer = optimizer._optimizer
+            warnings.warn("DistributedTrainer does not take DistributedOptimizer "
+                          "as its optimizer. We have unwrapped it for you.")
+
+        super(DistributedTrainer, self).__init__(
+            params, optimizer, optimizer_params=optimizer_params, kvstore=None)
+
+        # _scale is used to check and set rescale_grad for the optimizer in
+        # Trainer.step(). Normalizing it by the BytePS size, which is
+        # equivalent to performing an average in push_pull, has better
+        # performance.
+        self._scale /= size()
+
+    def _allreduce_grads(self):
+        for i, param in enumerate(self._params):
+            if param.grad_req != 'null':
+                byteps_push_pull(param.list_grad()[0], is_average=False,
+                                 name="parameter_" + str(i), priority=-i)
+
diff --git a/example/mxnet-gluon/run_mnist_gluon.sh b/example/mxnet-gluon/run_mnist_gluon.sh
new file mode 100644
index 000000000..2632e358c
--- /dev/null
+++ b/example/mxnet-gluon/run_mnist_gluon.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+export NVIDIA_VISIBLE_DEVICES=0,1
+export DMLC_WORKER_ID=0
+export DMLC_NUM_WORKER=1
+export DMLC_ROLE=worker
+
+# the following values do not matter for non-distributed jobs
+export DMLC_NUM_SERVER=1
+export DMLC_PS_ROOT_URI=127.0.0.1
+export DMLC_PS_ROOT_PORT=9000
+
+path="`dirname $0`"
+echo $path
+
+python $path/../../launcher/launch.py \
+    python $path/train_mnist_byteps.py
\ No newline at end of file
diff --git a/example/mxnet-gluon/train_mnist_byteps.py b/example/mxnet-gluon/train_mnist_byteps.py
new file mode 100644
index 000000000..2e2e0fbdf
--- /dev/null
+++ b/example/mxnet-gluon/train_mnist_byteps.py
@@ -0,0 +1,169 @@
+# Copyright 2019 Bytedance Inc. or its affiliates. All Rights Reserved.
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""This file is modified from `horovod/examples/mxnet_mnist.py`, using the Gluon-style MNIST dataset and DataLoader."""
+import time
+
+import argparse
+import logging
+
+import mxnet as mx
+import byteps.mxnet as bps
+from mxnet import autograd, gluon, nd
+from mxnet.gluon.data.vision import MNIST
+
+
+# Higher download speed for Chinese users (requires `import os`):
+# os.environ['MXNET_GLUON_REPO'] = 'https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/'
+
+# Training settings
+parser = argparse.ArgumentParser(description='MXNet MNIST Example')
+
+parser.add_argument('--batch-size', type=int, default=64,
+                    help='training batch size (default: 64)')
+parser.add_argument('--dtype', type=str, default='float32',
+                    help='training data type (default: float32)')
+parser.add_argument('--epochs', type=int, default=5,
+                    help='number of training epochs (default: 5)')
+parser.add_argument('--j', type=int, default=2,
+                    help='number of CPU worker processes for the DataLoader (default: 2)')
+parser.add_argument('--lr', type=float, default=0.01,
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--momentum', type=float, default=0.9,
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--no-cuda', action='store_true', default=False,
+                    help='disable training on GPU (default: False)')
+args = parser.parse_args()
+
+if not args.no_cuda:
+    # Disable CUDA if there are no GPUs.
+    if mx.context.num_gpus() == 0:
+        args.no_cuda = True
+
+logging.basicConfig(level=logging.INFO)
+logging.info(args)
+
+
+def dummy_transform(data, label):
+    im = data.astype(args.dtype, copy=False) / 255 - 0.5
+    im = nd.transpose(im, (2, 0, 1))
+    return im, label
+
+
+# Function to get the MNIST iterators
+def get_mnist_iterator():
+    train_set = MNIST(train=True, transform=dummy_transform)
+    train_iter = gluon.data.DataLoader(train_set, args.batch_size, shuffle=True, num_workers=args.j, last_batch='discard')
+    val_set = MNIST(train=False, transform=dummy_transform)
+    val_iter = gluon.data.DataLoader(val_set, args.batch_size, shuffle=False, num_workers=0)
+
+    return train_iter, val_iter, len(train_set)
+
+
+# Function to define the neural network
+def conv_nets():
+    net = gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
+        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
+        net.add(gluon.nn.Flatten())
+        net.add(gluon.nn.Dense(512, activation="relu"))
+        net.add(gluon.nn.Dense(10))
+    return net
+
+
+# Function to evaluate accuracy for a model
+def evaluate(model, data_iter, context):
+    metric = mx.metric.Accuracy()
+    for _, batch in enumerate(data_iter):
+        data = batch[0].as_in_context(context)
+        label = batch[1].as_in_context(context)
+        output = model(data.astype(args.dtype, copy=False))
+        metric.update([label], [output])
+
+    return metric.get()
+
+
+# Initialize BytePS
+bps.init()
+
+# BytePS: pin context to local rank
+context = mx.cpu(bps.local_rank()) if args.no_cuda else mx.gpu(bps.local_rank())
+num_workers = bps.size()
+
+# Load training and validation data
+train_data, val_data, train_size = get_mnist_iterator()
+
+# Build model
+model = conv_nets()
+model.cast(args.dtype)
+
+# Initialize parameters
+model.initialize(mx.init.MSRAPrelu(), ctx=context)
+# if bps.rank() == 0:
+model.summary(nd.ones((1, 1, 28, 28), ctx=context))
+model.hybridize()
+
+# BytePS: fetch and broadcast parameters
+params = model.collect_params()
+if params is not None:
+    bps.broadcast_parameters(params, root_rank=0)
+
+# BytePS: create DistributedTrainer, a subclass of gluon.Trainer
+optimizer_params = {'momentum': args.momentum, 'learning_rate': args.lr * num_workers}
+trainer = bps.DistributedTrainer(params, "sgd", optimizer_params)
+
+# Create loss function and train metric
+loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
+metric = mx.metric.Accuracy()
+
+# Train model
+for epoch in range(args.epochs):
+    tic = time.time()
+    metric.reset()
+    for i, batch in enumerate(train_data):
+        data = batch[0].as_in_context(context)
+        label = batch[1].as_in_context(context)
+
+        with autograd.record():
+            output = model(data)
+            loss = loss_fn(output, label)
+
+        loss.backward()
+        trainer.step(args.batch_size)
+        metric.update([label], [output])
+
+        if i % 100 == 0:
+            name, acc = metric.get()
+            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
+                         (epoch, i, name, acc))
+
+    if bps.rank() == 0:
+        elapsed = time.time() - tic
+        speed = train_size * num_workers / elapsed
+        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
+                     epoch, speed, elapsed)
+
+    # Evaluate model accuracy
+    _, train_acc = metric.get()
+    name, val_acc = evaluate(model, val_data, context)
+    if bps.rank() == 0:
+        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
+                     train_acc, name, val_acc)
+
+    if bps.rank() == 0 and epoch == args.epochs - 1:
+        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc
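
A rough sanity-check sketch, not part of the patch, of the gradient scaling that DistributedTrainer relies on: byteps_push_pull with is_average=False sums gradients across workers, and gluon's Trainer.step(batch_size) rescales the gradient by _scale / batch_size, where DistributedTrainer sets _scale = 1 / size(), so the net effect is the same as averaging inside push_pull. The worker count, batch size, and NumPy arrays below are made-up placeholders, assuming gluon's standard rescale_grad behavior.

import numpy as np

num_workers = 4        # stands in for bps.size()
batch_size = 64        # per-worker batch size passed to trainer.step()

# Per-worker gradients, already summed over each worker's local batch.
worker_grads = [np.random.randn(10) for _ in range(num_workers)]

# BytePS path: push_pull(summation) across workers ...
summed = np.sum(worker_grads, axis=0)
# ... then Trainer.step() applies rescale_grad = _scale / batch_size,
# with _scale = 1 / num_workers set by DistributedTrainer.
update_byteps = summed * (1.0 / num_workers) / batch_size

# Reference: average across workers, then divide by the local batch size.
update_reference = np.mean(worker_grads, axis=0) / batch_size

assert np.allclose(update_byteps, update_reference)

Because the resulting gradient is a per-sample average over the global batch, the example's multiplication of the learning rate by num_workers follows the usual linear learning-rate scaling practice for averaged gradients.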