Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mxnet 2 nightly build #2155

Closed
wants to merge 17 commits into from
36 changes: 27 additions & 9 deletions .buildkite/gen-pipeline.sh
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,15 @@ run_mpi_integration() {
"bash -c \"${oneccl_env} \\\$(cat /mpirun_command) python /horovod/examples/pytorch_mnist.py\""
fi

run_test "${test}" "${queue}" \
":muscle: Test MXNet MNIST (${test})" \
"bash -c \"${oneccl_env} OMP_NUM_THREADS=1 \\\$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py\""
if [[ ${test} == *"mxnet2_"* ]] || [[ ${test} == *"mxnethead"* ]]; then
run_test "${test}" "${queue}" \
":muscle: Test MXNet2 MNIST (${test})" \
"bash -c \"${oneccl_env} OMP_NUM_THREADS=1 \\\$(cat /mpirun_command) python /horovod/examples/mxnet2_mnist.py\""
else
run_test "${test}" "${queue}" \
":muscle: Test MXNet MNIST (${test})" \
"bash -c \"${oneccl_env} OMP_NUM_THREADS=1 \\\$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py\""
fi

# tests that should be executed only with the latest release since they don't test
# a framework-specific functionality
Expand Down Expand Up @@ -249,9 +255,15 @@ run_gloo_integration() {
":fire: Test PyTorch MNIST (${test})" \
"horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/pytorch_mnist.py"

run_test "${test}" "${queue}" \
":muscle: Test MXNet MNIST (${test})" \
"horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py"
if [[ ${test} == *"mxnet2_"* ]] || [[ ${test} == *"mxnethead"* ]]; then
run_test "${test}" "${queue}" \
":muscle: Test MXNet2 MNIST (${test})" \
"horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet2_mnist.py"
else
run_test "${test}" "${queue}" \
":muscle: Test MXNet MNIST (${test})" \
"horovodrun -np 2 -H localhost:2 --gloo python /horovod/examples/mxnet_mnist.py"
fi

# Elastic
local elastic_tensorflow="test_elastic_tensorflow.py"
Expand Down Expand Up @@ -334,9 +346,15 @@ run_single_integration() {
":fire: Single PyTorch MNIST (${test})" \
"bash -c \"${oneccl_env} python /horovod/examples/pytorch_mnist.py --epochs 3\""

run_test "${test}" "${queue}" \
":muscle: Single MXNet MNIST (${test})" \
"bash -c \"${oneccl_env} python /horovod/examples/mxnet_mnist.py --epochs 3\""
if [[ ${test} == *"mxnet2_"* ]] || [[ ${test} == *"mxnethead"* ]]; then
run_test "${test}" "${queue}" \
":muscle: Single MXNet2 MNIST (${test})" \
"bash -c \"${oneccl_env} python /horovod/examples/mxnet2_mnist.py --epochs 3\""
else
run_test "${test}" "${queue}" \
":muscle: Single MXNet MNIST (${test})" \
"bash -c \"${oneccl_env} python /horovod/examples/mxnet_mnist.py --epochs 3\""
fi
}

build_docs() {
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.test.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ RUN pip install "Pillow<7.0" --no-deps

# Install MXNet.
RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly" ]]; then \
pip install --pre mxnet-mkl -f https://dist.mxnet.io/python/all; \
pip install --pre mxnet -f https://dist.mxnet.io/python/all; \
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if there should be a separate pipeline for 1.x

else \
pip install ${MXNET_PACKAGE} ; \
fi
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile.test.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ RUN pip install "Pillow<7.0" --no-deps

# Install MXNet.
RUN if [[ ${MXNET_PACKAGE} == "mxnet-nightly" ]]; then \
pip install --pre mxnet-cu101mkl -f https://dist.mxnet.io/python/all; \
pip install --pre mxnet-cu101 -f https://dist.mxnet.io/python/all; \
Copy link

@szha szha Aug 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cu101 build is being discontinued as NVIDIA only supports the latest two major versions and minor versions. can the horovod project update its dependencies?

else \
pip install ${MXNET_PACKAGE} ; \
fi
Expand Down
171 changes: 171 additions & 0 deletions examples/mxnet2_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import argparse
import logging
import os
import zipfile
import time

import mxnet as mx
import horovod.mxnet as hvd
from mxnet import autograd, gluon, nd
from mxnet.test_utils import download

# Training settings
parser = argparse.ArgumentParser(description='MXNet MNIST Example')

parser.add_argument('--batch-size', type=int, default=64,
help='training batch size (default: 64)')
parser.add_argument('--dtype', type=str, default='float32',
help='training data type (default: float32)')
parser.add_argument('--epochs', type=int, default=5,
help='number of training epochs (default: 5)')
parser.add_argument('--lr', type=float, default=0.01,
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9,
help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disable training on GPU (default: False)')
args = parser.parse_args()

if not args.no_cuda:
# Disable CUDA if there are no GPUs.
if not mx.test_utils.list_gpus():
args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)


# Function to get mnist iterator given a rank
def get_mnist_iterator(rank):
data_dir = "data-%d" % rank
if not os.path.isdir(data_dir):
os.makedirs(data_dir)
zip_file_path = download('http://data.mxnet.io/mxnet/data/mnist.zip',
dirname=data_dir)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(data_dir)

input_shape = (1, 28, 28)
batch_size = args.batch_size

train_iter = mx.io.MNISTIter(
image="%s/train-images-idx3-ubyte" % data_dir,
label="%s/train-labels-idx1-ubyte" % data_dir,
input_shape=input_shape,
batch_size=batch_size,
shuffle=True,
flat=False,
num_parts=hvd.size(),
part_index=hvd.rank()
)

val_iter = mx.io.MNISTIter(
image="%s/t10k-images-idx3-ubyte" % data_dir,
label="%s/t10k-labels-idx1-ubyte" % data_dir,
input_shape=input_shape,
batch_size=batch_size,
flat=False,
)

return train_iter, val_iter


# Function to define neural network
def conv_nets():
net = gluon.nn.HybridSequential()
net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
net.add(gluon.nn.Flatten())
net.add(gluon.nn.Dense(512, activation="relu"))
net.add(gluon.nn.Dense(10))
return net


# Function to evaluate accuracy for a model
def evaluate(model, data_iter, context):
data_iter.reset()
metric = mx.gluon.metric.Accuracy()
for _, batch in enumerate(data_iter):
data = batch.data[0].as_in_context(context)
label = batch.label[0].as_in_context(context)
output = model(data.astype(args.dtype, copy=False))
metric.update([label], [output])

return metric.get()


# Initialize Horovod
hvd.init()

# Horovod: pin context to local rank
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())
num_workers = hvd.size()

# Load training and validation data
train_data, val_data = get_mnist_iterator(hvd.rank())

# Build model
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# Create optimizer
optimizer_params = {'momentum': args.momentum,
'learning_rate': args.lr * hvd.size()}
opt = mx.optimizer.create('sgd', **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.gluon.metric.Accuracy()

# Train model
for epoch in range(args.epochs):
tic = time.time()
train_data.reset()
metric.reset()
for nbatch, batch in enumerate(train_data, start=1):
data = batch.data[0].as_in_context(context)
label = batch.label[0].as_in_context(context)
with autograd.record():
output = model(data.astype(args.dtype, copy=False))
loss = loss_fn(output, label)
loss.backward()
trainer.step(args.batch_size)
metric.update([label], [output])

if nbatch % 100 == 0:
name, acc = metric.get()
logging.info('[Epoch %d Batch %d] Training: %s=%f' %
(epoch, nbatch, name, acc))

if hvd.rank() == 0:
elapsed = time.time() - tic
speed = nbatch * args.batch_size * hvd.size() / elapsed
logging.info('Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f',
epoch, speed, elapsed)

# Evaluate model accuracy
_, train_acc = metric.get()
name, val_acc = evaluate(model, val_data, context)
if hvd.rank() == 0:
logging.info('Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f', epoch, name,
train_acc, name, val_acc)

if hvd.rank() == 0 and epoch == args.epochs - 1:
assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected\
(0.96)" % val_acc
27 changes: 19 additions & 8 deletions horovod/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,21 @@ def __init__(self, params, optimizer, optimizer_params=None):
def _allreduce_grads(self):
if size() == 1: return

# In MXNet 2.0, param.name is no longer unique.
# Meanwhile, since horovod requires Python 3.6, there is no need to sort
# self._params as enumerating a python dict is always deterministic.
for i, param in enumerate(self._params):
if param.grad_req != 'null':
allreduce_(param.list_grad()[0], average=False,
name=param.name, priority=-i)
name=str(i), priority=-i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate more on the param.name being no longer unique in MXNet 2.0? I'd hesitate to revert back to using str(i) as the allreduce request names, as it would reintroduce the problem described in #1679. Basically, in cases where users may have multiple DistributedTrainers, this naming scheme will not differentiate between gradients being submitted by the different optimizers, instead producing multiple requests with the same name (i.e. multiple allreduce.0, allreduce.1, etc.).

I see in your changes to broadcast_parameters you use the dictionary key values, but I guess that isn't an option here because self._params is just a list of params, rather than the original dictionary.

Maybe one option would be to add an optional name argument for the DistributedTrainer that gets prepended to the names of the allreduce operations being submitted? That way, users can provide unique trainer names if they need to disambiguate allreduce operations from multiple trainers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review. As part of this PR https://github.com/apache/incubator-mxnet/pull/18619/files, the name scope class is removed to avoid the usage of thread local object in the python front-end. The parameters are now distinguished by their uuid instead (idx = self._param2idx[parameter._uuid]) and the parameter name saved in the Parameter class can be identical.

You brought up a valid point on the use case where multiple DistributedTrainers. Adding a name parameter to distributed trainer sounds reasonable.

Copy link
Collaborator Author

@eric-haibin-lin eric-haibin-lin Aug 7, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm adding a parameter -> name mapping in the constructor (self._param2name) in apache/mxnet#18877. Maybe we can use that, too.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using name = self._param2name[p._uuid] after apache/mxnet#18877 lands might be a cleaner option than adding a global name to the DistributedTrainer. Aligns more closely to the existing code and would be less disruptive to users who might be using multiple optimizers already. What do you think?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @leezu to suggest on the choice of unique names for parameters.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The names in self._param2name[p._uuid] will not be globally unique if the user creates multiple Blocks with the same structure but different parameters. Actually I'm not convinced we should add self._param2name[p._uuid]

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@romerojosh @eric-haibin-lin what's the uniqueness and consistency requirement for this identifier? Does it need to be consistent across different workers?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the allreduce operation needs to be consistent across all workers participating in the communication. This name identifier is what is used to match up tensors in the Horovod backend. Concerning uniqueness, names can be reused; however, all operations in flight must have unique names (i.e. a name can only be reused if the previous Horovod operation using that name has completed).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact after checking with @leezu, the following two blocks will have the same name for parameters in mxnet2:

net1 = nn.Dense(10)
net2 = nn.Dense(10)

assert net1.collect_params().keys() == net2.collect_params().keys()

So for horovod we probably need to go back to the approach where we add a name argument to DistributedTrainer.



# Wrapper to inject Horovod broadcast after parameter initialization
def _append_broadcast_init(param, root_rank):
def _append_broadcast_init(param, root_rank, name):
init_impl = getattr(param, '_init_impl')
def wrapped_init_impl(self, *args, **kwargs):
init_impl(*args, **kwargs)
broadcast_(self.data(), root_rank=root_rank, name=self.name)
broadcast_(self.data(), root_rank=root_rank, name=name)
return wrapped_init_impl


Expand All @@ -132,17 +135,25 @@ def broadcast_parameters(params, root_rank=0):

tensors = []
names = []
if isinstance(params, dict):
names, tensors = zip(*params.items())
elif isinstance(params, mx.gluon.parameter.ParameterDict):
try:
from mxnet.gluon.parameter import ParameterDict
valid_types = (dict, ParameterDict)
except ImportError:
valid_types = (dict,)
if isinstance(params, valid_types):
for name, p in sorted(params.items()):
try:
tensors.append(p.data())
if isinstance(p, mx.gluon.parameter.Parameter):
tensors.append(p.data())
else:
tensors.append(p)
names.append(name)
except mx.gluon.parameter.DeferredInitializationError:
# Inject wrapper method with post-initialization broadcast to
# handle parameters with deferred initialization
new_init = _append_broadcast_init(p, root_rank)
# we use the key of params instead of param.name, since
# param.name is no longer unique in MXNet 2.0
new_init = _append_broadcast_init(p, root_rank, name)
p._init_impl = types.MethodType(new_init, p)
else:
raise ValueError('invalid params of type: %s' % type(params))
Expand Down
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,14 @@ def build_mx_extension(build_ext, global_options):
mxnet_mpi_lib.define_macros += [('MXNET_USE_MKLDNN', '1')]
else:
mxnet_mpi_lib.define_macros += [('MXNET_USE_MKLDNN', '0')]
cxx11_abi = '1'
try:
import mxnet as mx
if int(mx.library.compiled_with_gcc_cxx11_abi()) == 0:
cxx11_abi = '0'
except AttributeError:
pass
mxnet_mpi_lib.define_macros += [('_GLIBCXX_USE_CXX11_ABI', cxx11_abi)]
mxnet_mpi_lib.define_macros += [('MSHADOW_USE_MKL', '0')]
mxnet_mpi_lib.define_macros += [('MSHADOW_USE_F16C', '0')]
mxnet_mpi_lib.include_dirs = options['INCLUDES']
Expand Down
12 changes: 6 additions & 6 deletions test/data/expected_buildkite_pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1342,8 +1342,8 @@ steps:
automatic: true
agents:
queue: cpu
- label: ':muscle: Test MXNet MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py"
- label: ':muscle: Test MXNet2 MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet2_mnist.py"
plugins:
- docker-compose#v2.6.0:
run: test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0
Expand Down Expand Up @@ -1412,8 +1412,8 @@ steps:
automatic: true
agents:
queue: cpu
- label: ':muscle: Single MXNet MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " python /horovod/examples/mxnet_mnist.py --epochs 3"
- label: ':muscle: Single MXNet2 MNIST (test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " python /horovod/examples/mxnet2_mnist.py --epochs 3"
plugins:
- docker-compose#v2.6.0:
run: test-cpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0
Expand Down Expand Up @@ -2744,8 +2744,8 @@ steps:
automatic: true
agents:
queue: 2x-gpu-g4
- label: ':muscle: Test MXNet MNIST (test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet_mnist.py"
- label: ':muscle: Test MXNet2 MNIST (test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0)'
command: bash -c " OMP_NUM_THREADS=1 \$(cat /mpirun_command) python /horovod/examples/mxnet2_mnist.py"
plugins:
- docker-compose#v2.6.0:
run: test-gpu-openmpi-py3_6-tfhead-kerashead-torchhead-mxnethead-pyspark2_4_0
Expand Down
13 changes: 6 additions & 7 deletions test/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,13 +847,12 @@ class SimpleNet(HybridBlock):
def __init__(self, layer_num=6, **kwargs):
super(SimpleNet, self).__init__(**kwargs)
self._layer_num = layer_num
with self.name_scope():
self.ln_l = nn.HybridSequential()
self.dense_l = nn.HybridSequential()
for i in range(layer_num):
self.dense_l.add(nn.Dense(units=32 + layer_num - 1 - i,
flatten=False))
self.ln_l.add(nn.LayerNorm())
self.ln_l = nn.HybridSequential()
self.dense_l = nn.HybridSequential()
for i in range(layer_num):
self.dense_l.add(nn.Dense(units=32 + layer_num - 1 - i,
flatten=False))
self.ln_l.add(nn.LayerNorm())

def hybrid_forward(self, F, data):
"""
Expand Down