Fix mxnet 2 nightly build #2155
@@ -116,7 +116,7 @@ RUN pip install "Pillow<7.0" --no-deps
 
 # Install MXNet.
 RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly" ]]; then \
-        pip install --pre mxnet-cu101mkl -f https://dist.mxnet.io/python/all; \
+        pip install --pre mxnet-cu101 -f https://dist.mxnet.io/python/all; \
Review comment: The cu101 build is being discontinued, since NVIDIA only supports the latest two major and minor versions. Can the Horovod project update its dependencies?
     else \
         pip install ${MXNET_PACKAGE} ; \
     fi
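As a usage sketch (hedged: the Dockerfile path and image tag below are assumptions; only the MXNET_PACKAGE build argument appears in this diff), the nightly variant could be selected at build time:

    # Hypothetical paths and tag; MXNET_PACKAGE is the argument tested above.
    docker build --build-arg MXNET_PACKAGE=mxnet-nightly -f Dockerfile.test.gpu -t horovod-test:gpu .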
@@ -0,0 +1,171 @@
import argparse
import logging
import os
import zipfile
import time

import mxnet as mx
import horovod.mxnet as hvd
from mxnet import autograd, gluon, nd
from mxnet.test_utils import download

# Training settings
parser = argparse.ArgumentParser(description='MXNet MNIST Example')

parser.add_argument('--batch-size', type=int, default=64,
                    help='training batch size (default: 64)')
parser.add_argument('--dtype', type=str, default='float32',
                    help='training data type (default: float32)')
parser.add_argument('--epochs', type=int, default=5,
                    help='number of training epochs (default: 5)')
parser.add_argument('--lr', type=float, default=0.01,
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable training on GPU (default: False)')
args = parser.parse_args()

if not args.no_cuda:
    # Disable CUDA if there are no GPUs.
    if not mx.test_utils.list_gpus():
        args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)


# Function to get mnist iterator given a rank
def get_mnist_iterator(rank):
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
                             dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
        num_parts=hvd.size(),
        part_index=hvd.rank()
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
    )

    return train_iter, val_iter


# Function to define neural network
def conv_nets():
    net = gluon.nn.HybridSequential()
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(512, activation="relu"))
    net.add(gluon.nn.Dense(10))
    return net


# Function to evaluate accuracy for a model
def evaluate(model, data_iter, context):
    data_iter.reset()
    metric = mx.gluon.metric.Accuracy()
    for _, batch in enumerate(data_iter):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])

    return metric.get()


# Initialize Horovod
hvd.init()

# Horovod: pin context to local rank
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Load training and validation data
train_data, val_data = get_mnist_iterator(hvd.rank())

# Build model
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# Create optimizer
optimizer_params = {'momentum': args.momentum,
                    'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
                             magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
    hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.gluon.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            logging.info('[Epoch %d Batch %d] Training: %s=%f' %
                         (epoch, nbatch, name, acc))

    if hvd.rank() == 0:
        elapsed = time.time() - tic
        speed = nbatch * args.batch_size * hvd.size() / elapsed
        logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
                     epoch, speed, elapsed)

    # Evaluate model accuracy
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if hvd.rank() == 0:
        logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
                     train_acc, name, val_acc)

    if hvd.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected\
                                (0.96)" % val_acc
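A hedged run sketch for this example (the script filename is an assumption; the flags are the ones defined by the argument parser above):

    # Hypothetical filename; launches two worker processes.
    horovodrun -np 2 python mxnet2_mnist.py --epochs 5 --batch-size 64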
@@ -100,18 +100,21 @@ def __init__(self, params, optimizer, optimizer_params=None):
     def _allreduce_grads(self):
         if size() == 1: return
 
+        # In MXNet 2.0, param.name is no longer unique.
+        # Meanwhile, since horovod requires Python 3.6, there is no need to sort
+        # self._params as enumerating a python dict is always deterministic.
         for i, param in enumerate(self._params):
             if param.grad_req != 'null':
                 allreduce_(param.list_grad()[0], average=False,
-                           name=param.name, priority=-i)
+                           name=str(i), priority=-i)
Review thread on this change:

- Could you elaborate more on the naming change? ... Maybe one option would be to add an optional name parameter to DistributedTrainer.

- Thanks for the review. As part of https://github.com/apache/incubator-mxnet/pull/18619/files, the name scope class is removed to avoid the use of a thread-local object in the Python front end. The parameters are now distinguished by their UUID instead. You brought up a valid point on the use case with multiple DistributedTrainers; adding a name parameter to DistributedTrainer sounds reasonable.

- I'm adding a parameter -> name mapping in the constructor ...

- I think using ...

- cc @leezu to suggest on the choice of unique names for parameters.

- The names in ...

- @romerojosh @eric-haibin-lin, what's the uniqueness and consistency requirement for this identifier? Does it need to be consistent across different workers?

- The ...

- In fact, after checking with @leezu, the following two blocks will have the same name for parameters in MXNet 2: ... So for Horovod we probably need to go back to the approach where we add a ...
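To make the duplicate-name issue concrete, here is a minimal sketch (an illustration, not code from this PR; assumes an MXNet 2.x nightly) of two blocks whose parameters share the same names once the name scope is gone:

    # Illustration only: duplicate parameter names in MXNet 2.x.
    from mxnet.gluon import nn

    net1 = nn.Dense(10)
    net2 = nn.Dense(10)

    # Both report ['weight', 'bias']; param.name alone is no longer a
    # unique key, which is why _allreduce_grads switches to str(i).
    print(list(net1.collect_params().keys()))
    print(list(net2.collect_params().keys()))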
 
 
 # Wrapper to inject Horovod broadcast after parameter initialization
-def _append_broadcast_init(param, root_rank):
+def _append_broadcast_init(param, root_rank, name):
     init_impl = getattr(param, '_init_impl')
     def wrapped_init_impl(self, *args, **kwargs):
         init_impl(*args, **kwargs)
-        broadcast_(self.data(), root_rank=root_rank, name=self.name)
+        broadcast_(self.data(), root_rank=root_rank, name=name)
     return wrapped_init_impl
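For context, a minimal sketch of the deferred initialization this wrapper handles (an illustration under assumed MXNet 2.x Gluon semantics, not code from this PR): a parameter whose shape depends on the first forward pass cannot be broadcast immediately, so the broadcast is injected into _init_impl instead:

    # Illustration only; assumes MXNet 2.x Gluon shape inference.
    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(10)            # in_units unspecified, shape deferred
    net.initialize()              # parameters not materialized yet
    try:
        net.weight.data()         # reading data too early fails
    except mx.gluon.parameter.DeferredInitializationError:
        print('shape unknown; broadcast must wait for _init_impl')

    net(mx.nd.zeros((1, 20)))     # first forward pass infers the shape
    print(net.weight.data().shape)  # (10, 20)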
@@ -132,17 +135,25 @@ def broadcast_parameters(params, root_rank=0):
 
     tensors = []
     names = []
-    if isinstance(params, dict):
-        names, tensors = zip(*params.items())
-    elif isinstance(params, mx.gluon.parameter.ParameterDict):
+    try:
+        from mxnet.gluon.parameter import ParameterDict
+        valid_types = (dict, ParameterDict)
+    except ImportError:
+        valid_types = (dict,)
+    if isinstance(params, valid_types):
         for name, p in sorted(params.items()):
             try:
-                tensors.append(p.data())
+                if isinstance(p, mx.gluon.parameter.Parameter):
+                    tensors.append(p.data())
+                else:
+                    tensors.append(p)
                 names.append(name)
             except mx.gluon.parameter.DeferredInitializationError:
                 # Inject wrapper method with post-initialization broadcast to
                 # handle parameters with deferred initialization
-                new_init = _append_broadcast_init(p, root_rank)
+                # we use the key of params instead of param.name, since
+                # param.name is no longer unique in MXNet 2.0
+                new_init = _append_broadcast_init(p, root_rank, name)
                 p._init_impl = types.MethodType(new_init, p)
     else:
         raise ValueError('invalid params of type: %s' % type(params))
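A hedged usage sketch of the two input types this function accepts after the change (model stands for any initialized Gluon block; the names are illustrative):

    # Sketch only; assumes hvd.init() was called and `model` exists.
    import mxnet as mx
    import horovod.mxnet as hvd

    # 1) A parameter container from a Gluon block: the dict keys become
    #    the broadcast names (unique even in MXNet 2.0).
    hvd.broadcast_parameters(model.collect_params(), root_rank=0)

    # 2) A plain dict of name -> NDArray, e.g. auxiliary state.
    state = {'running_mean': mx.nd.zeros((10,))}
    hvd.broadcast_parameters(state, root_rank=0)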
Review comment: I wonder if there should be a separate pipeline for 1.x.