From 148479b515ad4220ede19b64f52a9522ede23d01 Mon Sep 17 00:00:00 2001 From: muli Date: Fri, 9 Oct 2015 21:49:09 -0400 Subject: [PATCH 01/19] [IO] add back num_parts and part_index for imgrec --- src/io/iter_image_recordio.cc | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/io/iter_image_recordio.cc b/src/io/iter_image_recordio.cc index 37bfe0020b6d..6ca610e8a410 100644 --- a/src/io/iter_image_recordio.cc +++ b/src/io/iter_image_recordio.cc @@ -101,6 +101,11 @@ struct ImageRecParserParam : public dmlc::Parameter { int preprocess_threads; /*! \brief whether to remain silent */ bool verbose; + /*! \brief partition the data into multiple parts */ + int num_parts; + /*! \brief the index of the part will read*/ + int part_index; + // declare parameters DMLC_DECLARE_PARAMETER(ImageRecParserParam) { DMLC_DECLARE_FIELD(path_imglist).set_default("") @@ -116,6 +121,10 @@ struct ImageRecParserParam : public dmlc::Parameter { .describe("Backend Param: Number of thread to do preprocessing."); DMLC_DECLARE_FIELD(verbose).set_default(true) .describe("Auxiliary Param: Whether to output parser information."); + DMLC_DECLARE_FIELD(num_parts).set_default(1) + .describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0) + .describe("the index of the part will read"); } }; @@ -203,12 +212,9 @@ inline void ImageRecordIOParser::Init( LOG(INFO) << "ImageRecordIOParser: " << param_.path_imgrec << ", use " << threadget << " threads for decoding.."; } - // TODO(mu, tianjun) add DMLC env variable to detect parition - const int part_index = 0; - const int num_parts = 1; source_ = dmlc::InputSplit::Create( - param_.path_imgrec.c_str(), part_index, - num_parts, "recordio"); + param_.path_imgrec.c_str(), param_.part_index, + param_.num_parts, "recordio"); // use 64 MB chunk when possible source_->HintChunkSize(8 << 20UL); #else From 7fbdc3d9f8e957629b7e3792e761317c0685534f Mon Sep 17 00:00:00 2001 From: muli Date: Fri, 9 Oct 2015 22:42:03 -0400 Subject: [PATCH 02/19] [python] support dist kvstore --- python/mxnet/kvstore.py | 2 +- python/mxnet/model.py | 63 ++++++++++++++++++++------------------- python/mxnet/optimizer.py | 3 +- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index ffd7433281ae..640a11cc394a 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -261,7 +261,7 @@ def set_optimizer(self, optimizer): raise self._send_command_to_servers(0, optim_str) else: - self._set_updater(opt.optimizer_clossure(optimizer)) + self._set_updater(opt.get_updater(optimizer)) def get_rank(self): """Get the rank of this worker node diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 719f5385ebe7..abebf8a9e815 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -15,6 +15,7 @@ from .context import Context, cpu from .initializer import Uniform from collections import namedtuple +from .optimizer import get_updater BASE_ESTIMATOR = object @@ -180,7 +181,7 @@ def _train_multi_device(symbol, ctx, input_shape, update_on_kvstore : boolean, optional Whether to perform parameter update on kvstore instead of training device. - kvstore_type : {'local', 'device'}, optional + kvstore_type : {'local', 'device', 'dist'}, optional Type of kvstore used for synchronization. logger : logging logger @@ -220,39 +221,41 @@ def _train_multi_device(symbol, ctx, input_shape, texec.copy_params_from(arg_params, aux_params) # ky value store - kv = kvstore.create(kvstore_type) if num_device != 1 else None - if kv is None or kvstore_type == 'device': - update_on_kvstore = False - else: + if kvstore_type == 'dist': + kv = kvstore.create(kvstore_type) + update_on_kvstore = True + elif num_device != 1: + kv = kvstore.create(kvstore_type) # auto decide update_on_kvstore if update_on_kvstore is None: max_size = max(np.prod(param.shape) for param in arg_params.values()) update_on_kvstore = max_size < 1024 * 1024 * 16 logging.info('Auto-select update_on_kvstore=%s', str(update_on_kvstore)) + else: + # don't use kvstore for single machine and single device + update_on_kvstore = False + kv = None - opt_state_blocks = [] - # If there are multiple devices, initialize the weights. - for index, pair in enumerate(zip(arg_blocks, grad_blocks)): - arg_list, grad_list = pair - if grad_list[0] is not None: - if kv: - kv.init(index, arg_list[0]) - # attach state direct to weight - if update_on_kvstore: - opt_state_blocks.append(nd.zeros(arg_list[0].shape, cpu())) - else: - opt_list = [optimizer.create_state(index, w) for w in arg_list] - opt_state_blocks.append(opt_list) - else: - opt_state_blocks.append(None) + # init optimizer before give it to kv or get_updater + optimizer.begin_round(begin_round) + + if not update_on_kvstore: + updater = get_updater(optimizer) - def kv_updater(index, grad, weight): - """Internal updater on KVstore, used when update_on_kvstore=True.""" - optimizer.update(index, weight, grad, opt_state_blocks[index]) + if kv: + # init optimizer + if update_on_kvstore: + kv.set_optimizer(optimizer) - # pylint: disable=protected-access - if update_on_kvstore: - kv._set_updater(kv_updater) + # init kv + for index, pair in enumerate(zip(arg_blocks, grad_blocks)): + arg_list, grad_list = pair + if grad_list[0] is not None: + kv.init(index, arg_list[0]) + + # pull the weight back + if update_on_kvstore: + kv.pull(index, arg_list, priority=-index) # Input and output data structure data_index, label_index = _check_arguments(symbol) @@ -265,7 +268,6 @@ def kv_updater(index, grad, weight): for iteration in range(begin_round, end_round): # Training phase tic = time.time() - optimizer.begin_round(iteration) eval_metric.reset() nbatch = 0 # Iterate over training data. @@ -297,10 +299,9 @@ def kv_updater(index, grad, weight): # pull back the sum gradients, to the same locations. kv.pull(index, grad_list, priority=-index) if not update_on_kvstore: - opt_list = opt_state_blocks[index] - # optimizea - for w, g, state in zip(arg_list, grad_list, opt_list): - optimizer.update(index, w, g, state) + for w, g in zip(arg_list, grad_list): + updater(index, g, w) + nbatch += 1 # epoch callback (for print purpose) if epoch_end_callback != None: diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index da336aff29e4..ad2b1fbdd6e6 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -148,8 +148,7 @@ def create(name, rescale_grad=1, **kwargs): else: raise ValueError('Cannot find optimizer %s' % name) - -def optimizer_clossure(optimizer): +def get_updater(optimizer): """Return a clossure of the updater needed for kvstore Parameters From 2006810fa7dda0f2c3946824a763a5de5d391fc5 Mon Sep 17 00:00:00 2001 From: muli Date: Fri, 9 Oct 2015 23:32:25 -0400 Subject: [PATCH 03/19] [IO] multiple part support in mnist --- src/io/iter_mnist.cc | 43 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/io/iter_mnist.cc b/src/io/iter_mnist.cc index a1f03dbd8e83..064e6e9e3e6f 100644 --- a/src/io/iter_mnist.cc +++ b/src/io/iter_mnist.cc @@ -31,6 +31,10 @@ struct MNISTParam : public dmlc::Parameter { bool flat; /*! \brief random seed */ int seed; + /*! \brief partition the data into multiple parts */ + int num_parts; + /*! \brief the index of the part will read*/ + int part_index; // declare parameters DMLC_DECLARE_PARAMETER(MNISTParam) { DMLC_DECLARE_FIELD(image).set_default("./train-images-idx3-ubyte") @@ -47,6 +51,10 @@ struct MNISTParam : public dmlc::Parameter { .describe("Augmentation Param: Random Seed."); DMLC_DECLARE_FIELD(silent).set_default(false) .describe("Auxiliary Param: Whether to print out data info."); + DMLC_DECLARE_FIELD(num_parts).set_default(1) + .describe("partition the data into multiple parts"); + DMLC_DECLARE_FIELD(part_index).set_default(0) + .describe("the index of the part will read"); } }; @@ -113,13 +121,33 @@ class MNISTIter: public IIterator { } private: + inline void GetPart(int count, int* start, int *end) { + CHECK_GE(param_.part_index, 0); + CHECK_GT(param_.num_parts, 0); + CHECK_GT(param_.num_parts, param_.part_index); + + *start = static_cast( + static_cast(count) / param_.num_parts * param_.part_index); + *end = static_cast( + static_cast(count) / param_.num_parts * (param_.part_index+1)); + } + inline void LoadImage(void) { - dmlc::Stream *stdimg = dmlc::Stream::Create(param_.image.c_str(), "r"); + // dmlc::Stream *stdimg = dmlc::Stream::Create(param_.image.c_str(), "r"); + dmlc::SeekStream* stdimg + = dmlc::SeekStream::CreateForRead(param_.image.c_str()); ReadInt(stdimg); int image_count = ReadInt(stdimg); int image_rows = ReadInt(stdimg); int image_cols = ReadInt(stdimg); + int start, end; + GetPart(image_count, &start, &end); + image_count = end - start; + if (start > 0) { + stdimg->Seek(stdimg->Tell() + start * image_rows * image_cols); + } + img_.shape_ = mshadow::Shape3(image_count, image_rows, image_cols); img_.stride_ = img_.size(2); @@ -139,9 +167,20 @@ class MNISTIter: public IIterator { delete stdimg; } inline void LoadLabel(void) { - dmlc::Stream *stdlabel = dmlc::Stream::Create(param_.label.c_str(), "r"); + // dmlc::Stream *stdlabel = dmlc::Stream::Create(param_.label.c_str(), "r"); + + dmlc::SeekStream* stdlabel + = dmlc::SeekStream::CreateForRead(param_.label.c_str()); ReadInt(stdlabel); int labels_count = ReadInt(stdlabel); + + int start, end; + GetPart(labels_count, &start, &end); + labels_count = end - start; + if (start > 0) { + stdlabel->Seek(stdlabel->Tell() + start); + } + labels_.resize(labels_count); for (int i = 0; i < labels_count; ++i) { unsigned char ch; From e3353ab15632ef34e5c98fcb4399511729e54945 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 00:25:08 -0400 Subject: [PATCH 04/19] [kvstore] add test_mlp.py --- python/mxnet/kvstore.py | 3 ++ python/mxnet/model.py | 22 +++++++-- src/io/iter_mnist.cc | 2 - src/kvstore/kvstore_dist.h | 2 + tests/python/distributed/test_mlp.py | 73 ++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 7 deletions(-) create mode 100755 tests/python/distributed/test_mlp.py diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 640a11cc394a..29145bbbc4cf 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -169,6 +169,9 @@ def push(self, key, value, priority=0): self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) + # self._wait(key) + # self._barrier() + def pull(self, key, out=None, priority=0): """ Pull a single value or a sequence of values from the store. diff --git a/python/mxnet/model.py b/python/mxnet/model.py index abebf8a9e815..be607323cff2 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -130,7 +130,7 @@ def _train_multi_device(symbol, ctx, input_shape, train_data, eval_data=None, eval_metric=None, iter_end_callback=None, epoch_end_callback=None, update_on_kvstore=None, kvstore_type='local', - logger=None): + kv=None, logger=None): """Internal training function on multiple devices. This function will also work for single device as well. @@ -184,6 +184,9 @@ def _train_multi_device(symbol, ctx, input_shape, kvstore_type : {'local', 'device', 'dist'}, optional Type of kvstore used for synchronization. + kv : kvstore, optional + An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore + logger : logging logger When not specified, default logger will be used. @@ -221,7 +224,9 @@ def _train_multi_device(symbol, ctx, input_shape, texec.copy_params_from(arg_params, aux_params) # ky value store - if kvstore_type == 'dist': + if kv is not None: + update_on_kvstore = True + elif kvstore_type == 'dist': kv = kvstore.create(kvstore_type) update_on_kvstore = True elif num_device != 1: @@ -609,7 +614,7 @@ def predict(self, X): def fit(self, X, y=None, eval_data=None, eval_metric='acc', iter_end_callback=None, epoch_end_callback=None, update_on_kvstore=None, kvstore_type='local', - logger=None): + kvstore=None, logger=None): """Fit the model. Parameters @@ -643,6 +648,9 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', kvstore_type : {'local', 'device'}, optional Type of kvstore used for synchronization, usually no need to set. + kvstore : kvstore, optional + An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore + logger : logging logger, optional When not specified, default logger will be used. """ @@ -675,6 +683,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', epoch_end_callback=epoch_end_callback, update_on_kvstore=update_on_kvstore, kvstore_type=kvstore_type, + kv=kvstore, logger=logger) def save(self, prefix, iteration=None): @@ -740,7 +749,7 @@ def load(prefix, iteration, ctx=None, **kwargs): def create(symbol, X, y=None, ctx=None, num_round=None, optimizer='sgd', initializer=Uniform(0.01), eval_data=None, eval_metric='acc', iter_end_callback=None, - update_on_kvstore=None, kvstore_type='local', + update_on_kvstore=None, kvstore_type='local', kvstore=None, logger=None, **kwargs): """Functional style to create a model. @@ -790,6 +799,9 @@ def create(symbol, X, y=None, ctx=None, kvstore_type : {'local', 'device'}, optional Type of kvstore used for synchronization, usually no need to set. + kvstore : kvstore, optional + An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore + logger : logging logger, optional """ model = FeedForward(symbol, ctx=ctx, num_round=num_round, @@ -797,6 +809,6 @@ def create(symbol, X, y=None, ctx=None, model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, iter_end_callback=iter_end_callback, update_on_kvstore=update_on_kvstore, - kvstore_type=kvstore_type, + kvstore_type=kvstore_type, kvstore=kvstore, logger=logger) return model diff --git a/src/io/iter_mnist.cc b/src/io/iter_mnist.cc index 064e6e9e3e6f..6ac3415237a1 100644 --- a/src/io/iter_mnist.cc +++ b/src/io/iter_mnist.cc @@ -133,7 +133,6 @@ class MNISTIter: public IIterator { } inline void LoadImage(void) { - // dmlc::Stream *stdimg = dmlc::Stream::Create(param_.image.c_str(), "r"); dmlc::SeekStream* stdimg = dmlc::SeekStream::CreateForRead(param_.image.c_str()); ReadInt(stdimg); @@ -167,7 +166,6 @@ class MNISTIter: public IIterator { delete stdimg; } inline void LoadLabel(void) { - // dmlc::Stream *stdlabel = dmlc::Stream::Create(param_.label.c_str(), "r"); dmlc::SeekStream* stdlabel = dmlc::SeekStream::CreateForRead(param_.label.c_str()); diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 2721accd5c04..6bfcbe19a9ba 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -10,6 +10,7 @@ #include "./kvstore_local.h" #include "./mxnet_ps_node.h" #include "mxnet/engine.h" +// #include "dmlc/parameter.h" #include "ps.h" #include "base/range.h" @@ -43,6 +44,7 @@ class KVStoreDist : public KVStoreLocal { // stop the executor at servers SendCommandToServers(CommandID::kStop, ""); } + Barrier(); ps::StopSystem(); } } diff --git a/tests/python/distributed/test_mlp.py b/tests/python/distributed/test_mlp.py new file mode 100755 index 000000000000..7b5c55588644 --- /dev/null +++ b/tests/python/distributed/test_mlp.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# pylint: skip-file + +import mxnet as mx +import numpy as np +import os, sys +import pickle as pickle +import logging +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../common/')) +import models +import get_data + +# symbol net +batch_size = 100 +data = mx.symbol.Variable('data') +fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) +act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") +fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) +act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") +fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) +softmax = mx.symbol.Softmax(fc3, name = 'sm') + +def accuracy(label, pred): + py = np.argmax(pred, axis=1) + return np.sum(py == label) / float(label.size) + +num_round = 4 +prefix = './mlp' + +kv = mx.kvstore.create('dist') +batch_size /= kv.get_num_workers() + +#check data +get_data.GetMNIST_ubyte() + +train_dataiter = mx.io.MNISTIter( + image="data/train-images-idx3-ubyte", + label="data/train-labels-idx1-ubyte", + data_shape=(784,), num_parts=kv.get_num_workers(), part_index=kv.get_rank(), + batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10) +val_dataiter = mx.io.MNISTIter( + image="data/t10k-images-idx3-ubyte", + label="data/t10k-labels-idx1-ubyte", + data_shape=(784,), + batch_size=batch_size, shuffle=True, flat=True, silent=False) + +def test_mlp(): + logging.basicConfig(level=logging.DEBUG) + + model = mx.model.FeedForward.create( + softmax, + X=train_dataiter, + eval_data=val_dataiter, + eval_metric=mx.metric.np(accuracy), + ctx=[mx.cpu(i) for i in range(1)], + num_round=num_round, + learning_rate=0.05, wd=0.0004, + momentum=0.9, + kvstore=kv, + ) + logging.info('Finish traning...') + prob = model.predict(val_dataiter) + logging.info('Finish predict...') + val_dataiter.reset() + y = np.concatenate([label.asnumpy() for _, label in val_dataiter]).astype('int') + py = np.argmax(prob, axis=1) + acc = float(np.sum(py == y)) / len(y) + logging.info('final accuracy = %f', acc) + assert(acc > 0.93) + +if __name__ == "__main__": + test_mlp() From f9fcf8dfba3c6cb61809ce537f1c5e2b36891d14 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:01:32 -0400 Subject: [PATCH 05/19] [doc] multi node --- doc/multi_node.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 doc/multi_node.md diff --git a/doc/multi_node.md b/doc/multi_node.md new file mode 100644 index 000000000000..8200e08842b6 --- /dev/null +++ b/doc/multi_node.md @@ -0,0 +1,12 @@ +# Multi-devices and multi-machines + +![ps arch](https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/mxnet/multi-node/ps_arch.png) + +| kvstore type | updt on kvstore | multi-devs | multi-workers | #ex per updt | max delay | updt place | +| :--- | :--- | ---:| ---:| ---:| ---:| ---:| +| none | no | no | no | *b* | *0* | worker\_0's dev\_0 | +| local | yes | yes | no | *b* | *0* | worker_0's cpu | +| local | no | yes | no | *b* | *0* | worker\_0's devs | +| device | no | yes | no | *b* | *0* | worker\_0's devs | +| dist | yes | yes | yes | *b* | *n* | servers | +| dist | no | yes | yes | *b × n* | *0* | workers' cpu | From a76453d116fc80304df402da31889ae7ccc7d1b6 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:05:29 -0400 Subject: [PATCH 06/19] [doc] multi-node --- doc/multi_node.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index 8200e08842b6..a01ac4119024 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -1,12 +1,14 @@ # Multi-devices and multi-machines -![ps arch](https://raw.githubusercontent.com/dmlc/dmlc.github.io/master/img/mxnet/multi-node/ps_arch.png) +Architecture -| kvstore type | updt on kvstore | multi-devs | multi-workers | #ex per updt | max delay | updt place | + + +| kvstore type | update on kvstore | multi-devices | multi-workers | #ex per update | max delay | update place | | :--- | :--- | ---:| ---:| ---:| ---:| ---:| -| none | no | no | no | *b* | *0* | worker\_0's dev\_0 | -| local | yes | yes | no | *b* | *0* | worker_0's cpu | -| local | no | yes | no | *b* | *0* | worker\_0's devs | -| device | no | yes | no | *b* | *0* | worker\_0's devs | +| none | no | no | no | *b* | *0* | worker0<\sub>'s dev0<\sub> | +| local | yes | yes | no | *b* | *0* | worker0<\sub>'s cpu | +| local | no | yes | no | *b* | *0* | worker0<\sub>'s devs | +| device | no | yes | no | *b* | *0* | worker0<\sub>'s devs | | dist | yes | yes | yes | *b* | *n* | servers | | dist | no | yes | yes | *b × n* | *0* | workers' cpu | From facfcdef6c9a816b797589afacb530c41906c849 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:06:25 -0400 Subject: [PATCH 07/19] [doc] mn --- doc/multi_node.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index a01ac4119024..b9aa2234c38a 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -2,13 +2,13 @@ Architecture - + | kvstore type | update on kvstore | multi-devices | multi-workers | #ex per update | max delay | update place | | :--- | :--- | ---:| ---:| ---:| ---:| ---:| -| none | no | no | no | *b* | *0* | worker0<\sub>'s dev0<\sub> | -| local | yes | yes | no | *b* | *0* | worker0<\sub>'s cpu | -| local | no | yes | no | *b* | *0* | worker0<\sub>'s devs | -| device | no | yes | no | *b* | *0* | worker0<\sub>'s devs | +| none | no | no | no | *b* | *0* | worker0's dev0<\sub> | +| local | yes | yes | no | *b* | *0* | worker0's cpu | +| local | no | yes | no | *b* | *0* | worker0's devs | +| device | no | yes | no | *b* | *0* | worker0's devs | | dist | yes | yes | yes | *b* | *n* | servers | | dist | no | yes | yes | *b × n* | *0* | workers' cpu | From 5b95f3441c07f44ca60c4623c8b29e7d5f11f63d Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:29:03 -0400 Subject: [PATCH 08/19] [doc] multip --- doc/multi_node.md | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index b9aa2234c38a..621f98385431 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -1,14 +1,31 @@ # Multi-devices and multi-machines -Architecture +## Architecture + +A device could be a GPU card, CPU, or other computational units. -| kvstore type | update on kvstore | multi-devices | multi-workers | #ex per update | max delay | update place | +- *b*: the batch size set by users +- *k*: the number of devices used on a worker (could vary for different workers) +- *n*: the number of workers (often mean machines) + +- *number examples per update*: for each update, the number of examples used to + calculate the averaged gradients. Often the larger, the slower the convergence. +- *number examples per device*: the number of examples batched to one device + each time. Often the larger, the better the performance. +- *max delay*: The maximal delay of the weight a worker can get. Given a worker, + a delay *d* for weight *w* means when this worker uses *w* (to calculate the + gradient), *w* have been already updated by *d* times on some other places. A + larger delay often improves the performance, but may slows down the + convergence. + + +| kvstore type | update on kvstore | multiple? | #ex per dev | #ex per update | max delay | update place | | :--- | :--- | ---:| ---:| ---:| ---:| ---:| -| none | no | no | no | *b* | *0* | worker0's dev0<\sub> | -| local | yes | yes | no | *b* | *0* | worker0's cpu | -| local | no | yes | no | *b* | *0* | worker0's devs | -| device | no | yes | no | *b* | *0* | worker0's devs | -| dist | yes | yes | yes | *b* | *n* | servers | -| dist | no | yes | yes | *b × n* | *0* | workers' cpu | +| none | no | no | *b* | *b* | *0* | dev0 on worker0 | +| local | yes | devs| *b/k* | *b* | *0* | cpu on worker0 | +| local | no | devs | *b/k* | *b* | *0* | devs on worker0 | +| device | no | devs | *b/k* |*b* | *0* | devs on worker0 | +| dist | yes | devs + workers | yes | *b/k* |*b* | *n* | servers | +| dist | no | devs + workers | *b/k* | *b × n* | *0* | cpus on workers | From 815143397421d5e65385c854c6fb3ccdbbfda9b9 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:31:42 -0400 Subject: [PATCH 09/19] [doc] multi --- doc/multi_node.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index 621f98385431..b09630f58083 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -21,11 +21,11 @@ A device could be a GPU card, CPU, or other computational units. convergence. -| kvstore type | update on kvstore | multiple? | #ex per dev | #ex per update | max delay | update place | +| kvstore type | update on kvstore | multi devices | multi workers | #ex per device | #ex per update | max delay | update place | | :--- | :--- | ---:| ---:| ---:| ---:| ---:| -| none | no | no | *b* | *b* | *0* | dev0 on worker0 | -| local | yes | devs| *b/k* | *b* | *0* | cpu on worker0 | -| local | no | devs | *b/k* | *b* | *0* | devs on worker0 | -| device | no | devs | *b/k* |*b* | *0* | devs on worker0 | -| dist | yes | devs + workers | yes | *b/k* |*b* | *n* | servers | -| dist | no | devs + workers | *b/k* | *b × n* | *0* | cpus on workers | +| none | no | no | no | *b* | *b* | *0* | dev0 on worker0 | +| local | yes | yes | no | *b / k* | *b* | *0* | cpu on worker0 | +| local | no | yes | no | *b/k* | *b* | *0* | devs on worker0 | +| device | no | yes | no | *b/k* |*b* | *0* | devs on worker0 | +| dist | yes | yes | no | yes | *b/k* |*b* | *n* | servers | +| dist | no | yes | no | *b/k* | *b × n* | *0* | cpus on workers | From 829f143850b9eea562b67720b41258d66406f782 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:32:00 -0400 Subject: [PATCH 10/19] [doc] update --- doc/multi_node.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index b09630f58083..2eaac8da6a80 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -22,7 +22,7 @@ A device could be a GPU card, CPU, or other computational units. | kvstore type | update on kvstore | multi devices | multi workers | #ex per device | #ex per update | max delay | update place | -| :--- | :--- | ---:| ---:| ---:| ---:| ---:| +| :--- | :--- | ---:| ---:| ---:| ---:| ---:| ---:| | none | no | no | no | *b* | *b* | *0* | dev0 on worker0 | | local | yes | yes | no | *b / k* | *b* | *0* | cpu on worker0 | | local | no | yes | no | *b/k* | *b* | *0* | devs on worker0 | From 591c99e1cdb50c620f088ffca1f8d652a997c3bd Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 18:57:39 -0400 Subject: [PATCH 11/19] [doc] update --- doc/multi_node.md | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index 2eaac8da6a80..dd06d66269d2 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -6,26 +6,38 @@ A device could be a GPU card, CPU, or other computational units. -- *b*: the batch size set by users -- *k*: the number of devices used on a worker (could vary for different workers) -- *n*: the number of workers (often mean machines) +- **n** : the number of workers (often mean machines) +- **k** : the number of devices used on a worker (could vary for different workers) +- **b** : the batch size set by users -- *number examples per update*: for each update, the number of examples used to +- **number examples per update** : for each update, the number of examples used to calculate the averaged gradients. Often the larger, the slower the convergence. -- *number examples per device*: the number of examples batched to one device +- **number examples per device** : the number of examples batched to one device each time. Often the larger, the better the performance. -- *max delay*: The maximal delay of the weight a worker can get. Given a worker, +- **max delay** : The maximal delay of the weight a worker can get. Given a worker, a delay *d* for weight *w* means when this worker uses *w* (to calculate the gradient), *w* have been already updated by *d* times on some other places. A larger delay often improves the performance, but may slows down the convergence. -| kvstore type | update on kvstore | multi devices | multi workers | #ex per device | #ex per update | max delay | update place | +| kvstore type | multi devices | multi workers | #ex per device | #ex per update | max delay | | :--- | :--- | ---:| ---:| ---:| ---:| ---:| ---:| -| none | no | no | no | *b* | *b* | *0* | dev0 on worker0 | -| local | yes | yes | no | *b / k* | *b* | *0* | cpu on worker0 | -| local | no | yes | no | *b/k* | *b* | *0* | devs on worker0 | -| device | no | yes | no | *b/k* |*b* | *0* | devs on worker0 | -| dist | yes | yes | no | yes | *b/k* |*b* | *n* | servers | -| dist | no | yes | no | *b/k* | *b × n* | *0* | cpus on workers | +| `none` | no | no | *b* | *b* | *0* | +| `local` / `device` | yes | no | *b / k* | *b* | *0* | +| `dist_async` | yes | yes | yes | *b / k* | *b* | *n* +| `dist_sync` | no | yes | yes | *b / k* | *b × n* | *0* + + + + +- **2-4** for single machine with multiple devices. They are identical for + convergence, but may vary on performance. + - + dev0 on worker0 | + + cpu on worker0 | + devs on worker0 | + devs on worker0 | + | servers | + | cpus on workers | From d6590b245dc4a678f484d42d35159e252793583b Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 19:39:44 -0400 Subject: [PATCH 12/19] [doc] update --- doc/multi_node.md | 63 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/doc/multi_node.md b/doc/multi_node.md index dd06d66269d2..6fbdc3c708bb 100644 --- a/doc/multi_node.md +++ b/doc/multi_node.md @@ -21,23 +21,60 @@ A device could be a GPU card, CPU, or other computational units. convergence. -| kvstore type | multi devices | multi workers | #ex per device | #ex per update | max delay | -| :--- | :--- | ---:| ---:| ---:| ---:| ---:| ---:| +| kvstore type | multi-devices | multi-workers | #ex per device | #ex per update | max delay | +| :--- | --- | --- | --- | --- | --- | | `none` | no | no | *b* | *b* | *0* | | `local` / `device` | yes | no | *b / k* | *b* | *0* | -| `dist_async` | yes | yes | yes | *b / k* | *b* | *n* -| `dist_sync` | no | yes | yes | *b / k* | *b × n* | *0* +| `dist_sync` | yes | yes | *b / k* | *b × n* | *0* | +| `dist_async` | yes | yes | *b / k* | *b* | inf | +## Multiple devices on a single machine +Both `local` and `device` can handle the situation that a single machine with +multiple devices. They give the some results (model accuracy) as the single +device case. But comparing to the latter, each device only processes *1 / k* +examples each time (also consumes *1 / k* device memory), so we often increase +the batch size *b* for better system performance. -- **2-4** for single machine with multiple devices. They are identical for - convergence, but may vary on performance. - - - dev0 on worker0 | +We can further fine tune the system performance by specifying where to average +the gradients over all devices, and where to update the weight: - cpu on worker0 | - devs on worker0 | - devs on worker0 | - | servers | - | cpus on workers | +| case | kvstore type | update on kvstore | average gradient | perform update | +| :--- | :--- | :--- | --- | --- | --- | +| 1 | 'local' | yes | CPU | CPU | +| 2 | 'local' | no | CPU | all devices | +| 3 | 'device | yes | a device | all devices | + +- On case 1, gradients are first copied to main memory, next averaged on CPU, + and then update the weight on CPU. It is suitable when the average size of + weights are not large and there are a large number of weight. For example the + google Inception network. + +- Case 2 is similar to 1 except that the averaged gradients are copied back to + the devices, and then weights are updated on devices. It is faster than 1 when + the weight size is large so we can use the device to accelerate the computation + (but we increase the workload by *k* times). Examples are AlexNet on + imagenet. + +- Case 3 is similar to 1 except that the gradient are averaged on a chosen + device. It may take advantage of the possible device-to-device communication, and may + accelerate the averaging step. It is faster than 2 when the gradients are + huge. But it requires more device memory. + +## Multiple machines + + +Both `dist_async` and `dist_sync` can handle the multiple machines +situation. But they are different on both semantic and performance. + +- `dist_sync`: the gradients are first averaged on the servers, and then send to + back to workers for updating the weight. It is similar to `local` and + `update_on_kvstore=false` if we treat a machine as a device. It guarantees + almost identical convergence with the single machine single device situation + if reduces the batch size to *b / n*. However, it requires synchronization + between all workers, and therefore may harm the system performance. + +- `dist_async`: the gradient is sent to the servers, and the weight is updated + there. The weights a worker has may be stale. + (TODO) make the max delay be settable? From 4c47c27dd46bfd759cb5127cce804f5ef224238f Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 22:06:23 -0400 Subject: [PATCH 13/19] [doc] update --- doc/{ => developer-guide}/multi_node.md | 60 +++++++++++++++---------- 1 file changed, 36 insertions(+), 24 deletions(-) rename doc/{ => developer-guide}/multi_node.md (56%) diff --git a/doc/multi_node.md b/doc/developer-guide/multi_node.md similarity index 56% rename from doc/multi_node.md rename to doc/developer-guide/multi_node.md index 6fbdc3c708bb..fbf892025b4d 100644 --- a/doc/multi_node.md +++ b/doc/developer-guide/multi_node.md @@ -24,47 +24,59 @@ A device could be a GPU card, CPU, or other computational units. | kvstore type | multi-devices | multi-workers | #ex per device | #ex per update | max delay | | :--- | --- | --- | --- | --- | --- | | `none` | no | no | *b* | *b* | *0* | -| `local` / `device` | yes | no | *b / k* | *b* | *0* | +| `local` | yes | no | *b / k* | *b* | *0* | | `dist_sync` | yes | yes | *b / k* | *b × n* | *0* | | `dist_async` | yes | yes | *b / k* | *b* | inf | ## Multiple devices on a single machine -Both `local` and `device` can handle the situation that a single machine with -multiple devices. They give the some results (model accuracy) as the single -device case. But comparing to the latter, each device only processes *1 / k* -examples each time (also consumes *1 / k* device memory), so we often increase -the batch size *b* for better system performance. +KV store `local` synchronizes data over multiple devices on a single machine. +It gives the same results (e.g. model accuracy) as the single device case. But +comparing to the latter, assume there are *k* devices, then each device only +processes *1 / k* examples each time (also consumes *1 / k* device memory). We +often increase the batch size *b* for better system performance. -We can further fine tune the system performance by specifying where to average -the gradients over all devices, and where to update the weight: +When using `local`, the system will automatically chooses one of the following +three types. Their differences are on where to average +the gradients over all devices, and where to update the weight. -| case | kvstore type | update on kvstore | average gradient | perform update | -| :--- | :--- | :--- | --- | --- | --- | -| 1 | 'local' | yes | CPU | CPU | -| 2 | 'local' | no | CPU | all devices | -| 3 | 'device | yes | a device | all devices | -- On case 1, gradients are first copied to main memory, next averaged on CPU, +They produce +(almost) the same results, but may vary on speed. + +share the +same semantic + +They are semantically identical, but their speemay have different +speeds +We can further fine tune the system performance by specifying : + +| kvstore type | average gradient | perform update | +| :--- | :--- | --- | +| `local_update_cpu` | CPU | CPU | +| `local_allreduce_cpu` | CPU | all devices | +| `local_allreduce_device` | a device | all devices | + +- `local_update_cpu`, gradients are first copied to main memory, next averaged on CPU, and then update the weight on CPU. It is suitable when the average size of weights are not large and there are a large number of weight. For example the google Inception network. -- Case 2 is similar to 1 except that the averaged gradients are copied back to - the devices, and then weights are updated on devices. It is faster than 1 when - the weight size is large so we can use the device to accelerate the computation - (but we increase the workload by *k* times). Examples are AlexNet on - imagenet. +- `local_allreduce_cpu` is similar to `local_update_cpu` except that the + averaged gradients are copied back to the devices, and then weights are + updated on devices. It is faster than 1 when the weight size is large so we + can use the device to accelerate the computation (but we increase the workload + by *k* times). Examples are AlexNet on imagenet. -- Case 3 is similar to 1 except that the gradient are averaged on a chosen - device. It may take advantage of the possible device-to-device communication, and may - accelerate the averaging step. It is faster than 2 when the gradients are - huge. But it requires more device memory. +- `local_allreduce_device` is similar to `local_allreduce_cpu` except that the + gradient are averaged on a chosen device. It may take advantage of the + possible device-to-device communication, and may accelerate the averaging + step. It is faster than 2 when the gradients are huge. But it requires more + device memory. ## Multiple machines - Both `dist_async` and `dist_sync` can handle the multiple machines situation. But they are different on both semantic and performance. From 76c87240ea7dcbdbe4b07f8ece7faeb93c28f982 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 23:23:16 -0400 Subject: [PATCH 14/19] [kvstore] refactor the kvstore types --- include/mxnet/c_api.h | 8 ++ include/mxnet/kvstore.h | 20 ++++- python/mxnet/kvstore.py | 13 ++++ python/mxnet/model.py | 105 +++++++++++++------------- src/c_api.cc | 7 ++ src/kvstore/kvstore.cc | 33 +++++--- tests/python/unittest/test_kvstore.py | 5 ++ 7 files changed, 128 insertions(+), 63 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index ce4e46b68498..c56126cddc0b 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -787,6 +787,14 @@ MXNET_DLL int MXKVStoreSetUpdater(KVStoreHandle handle, MXKVStoreUpdater updater); +/*! + * \brief get the type of the kvstore + * \param handle handle to the KVStore + * \param type a string type + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreGetType(KVStoreHandle handle, + const char** type); //-------------------------------------------- // Part 6: advanced KVStore for multi-machines //-------------------------------------------- diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index da7d94a75cfd..e97fbf06035f 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -25,13 +25,22 @@ class KVStore { /*! * \brief Factory function to create a new KVStore. - * \param type The type of the kvstore, can be "local" or "dist" - * - local works for multiple devices on a single machine (single process) - * - dist works for multi-machines (multiple processes) + * \param type The type of the kvstore, + * 'local' : multi-devices on a single machine. can be also + * 'local_update_cpu', 'local_allreduce_cpu' + * 'device' or 'local_allreduce_device' : same to local but use gpus for kv + * allreduce + * 'dist_sync' : multi-machines with BSP + * 'dist_async' : multi-machines with partical asynchronous * \return a new created KVStore. */ static KVStore *Create(const char *type = "local"); + /** + * \brief return the type + */ + inline std::string type() { return type_; } + /*! * \brief Initialize a list of key-value pair to the store. * @@ -269,6 +278,11 @@ class KVStore { * \brief the user-defined updater */ Updater updater_; + + /** + * \brief the kvstore type + */ + std::string type_; }; } // namespace mxnet diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 29145bbbc4cf..98eda59a9b15 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -63,6 +63,19 @@ def __init__(self, handle): def __del__(self): check_call(_LIB.MXKVStoreFree(self.handle)) + def get_type(self): + """Get the type of this kvstore + + Returns + ------- + type : str + the string type + """ + kv_type = ctypes.c_char_p() + check_call(_LIB.MXKVStoreGetType(self.handle, ctypes.byref(kv_type))) + return kv_type.value + + def init(self, key, value): """ Initialize a single or a sequence of key-value pairs into the store. diff --git a/python/mxnet/model.py b/python/mxnet/model.py index be607323cff2..c97f550cec0c 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -11,7 +11,7 @@ from . import symbol as sym from . import optimizer as opt from . import metric -from . import kvstore +from . import kvstore as kvs from .context import Context, cpu from .initializer import Uniform from collections import namedtuple @@ -129,8 +129,7 @@ def _train_multi_device(symbol, ctx, input_shape, begin_round, end_round, optimizer, train_data, eval_data=None, eval_metric=None, iter_end_callback=None, epoch_end_callback=None, - update_on_kvstore=None, kvstore_type='local', - kv=None, logger=None): + kvstore='local', logger=None): """Internal training function on multiple devices. This function will also work for single device as well. @@ -178,14 +177,15 @@ def _train_multi_device(symbol, ctx, input_shape, A callback that is invoked at end of each batch. This can be used to measure speed, get result from evaluation metric. etc. - update_on_kvstore : boolean, optional - Whether to perform parameter update on kvstore instead of training device. + kvstore: KVStore or str, optional + The KVStore or a string kvstore type: + 'local' : multi-devices on a single machine, will automatically + choose one from 'local_update_cpu', 'local_allreduce_cpu', and + 'local_allreduce_device' + 'dist_sync' : multi-machines with BSP + 'dist_async' : multi-machines with partical asynchronous - kvstore_type : {'local', 'device', 'dist'}, optional - Type of kvstore used for synchronization. - - kv : kvstore, optional - An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore + In default uses 'local', often no need to change for single machiine. logger : logging logger When not specified, default logger will be used. @@ -223,23 +223,31 @@ def _train_multi_device(symbol, ctx, input_shape, for texec in train_execs: texec.copy_params_from(arg_params, aux_params) - # ky value store - if kv is not None: - update_on_kvstore = True - elif kvstore_type == 'dist': - kv = kvstore.create(kvstore_type) - update_on_kvstore = True - elif num_device != 1: - kv = kvstore.create(kvstore_type) - # auto decide update_on_kvstore - if update_on_kvstore is None: - max_size = max(np.prod(param.shape) for param in arg_params.values()) - update_on_kvstore = max_size < 1024 * 1024 * 16 - logging.info('Auto-select update_on_kvstore=%s', str(update_on_kvstore)) + # create kvstore + if isinstance(kvstore, KVStore): + kv = kvstore + elif isinstance(kvstore, str): + # create kvstore using the string type + if num_device is 1 and 'dist' not in kvstore: + # no need to use kv for single device and single machine + kv = None + else: + if kvstore is 'local': + # automatically select a proper local + max_size = max(np.prod(param.shape) for param in arg_params.values()) + if max_size < 1024 * 1024 * 16: + kvstore = 'local_update_cpu' + else: + kvstore = 'local_allreduce_cpu' + logging.info('Auto-select kvstore type = %s', kvstore) + kv = kvs.create(kvstore) else: - # don't use kvstore for single machine and single device - update_on_kvstore = False - kv = None + raise TypeError('kvstore must be either KVStore or str') + + # detect whether or not update weight on kvstore + update_on_kvstore = False + if kv and 'local_allreduce' in kv.get_type(): + update_on_kvstore = True # init optimizer before give it to kv or get_updater optimizer.begin_round(begin_round) @@ -613,8 +621,7 @@ def predict(self, X): def fit(self, X, y=None, eval_data=None, eval_metric='acc', iter_end_callback=None, epoch_end_callback=None, - update_on_kvstore=None, kvstore_type='local', - kvstore=None, logger=None): + kvstore='local', logger=None): """Fit the model. Parameters @@ -641,18 +648,19 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', A callback that is invoked at end of each batch For print purpose - update_on_kvstore: boolean, optional - Whether to perform parameter update on kvstore instead of training device. - By default, the trainer will automatically decide the policy. + kvstore: KVStore or str, optional + The KVStore or a string kvstore type: + 'local' : multi-devices on a single machine, will automatically + choose one from 'local_update_cpu', 'local_allreduce_cpu', and + 'local_allreduce_device' + 'dist_sync' : multi-machines with BSP + 'dist_async' : multi-machines with partical asynchronous - kvstore_type : {'local', 'device'}, optional - Type of kvstore used for synchronization, usually no need to set. - - kvstore : kvstore, optional - An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore + In default uses 'local', often no need to change for single machiine. logger : logging logger, optional When not specified, default logger will be used. + """ X = self._init_iter(X, y, is_train=True) # Simply ignore the first example to get input_shape @@ -681,9 +689,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', eval_metric=eval_metric, iter_end_callback=iter_end_callback, epoch_end_callback=epoch_end_callback, - update_on_kvstore=update_on_kvstore, - kvstore_type=kvstore_type, - kv=kvstore, + kvstore=kvstore, logger=logger) def save(self, prefix, iteration=None): @@ -792,23 +798,20 @@ def create(symbol, X, y=None, ctx=None, A callback that is invoked at end of each iteration. This can be used to checkpoint model each iteration. - update_on_kvstore: boolean, optional - Whether to perform parameter update on kvstore instead of training device. - By default, the trainer will automatically decide the policy. - - kvstore_type : {'local', 'device'}, optional - Type of kvstore used for synchronization, usually no need to set. + kvstore: KVStore or str, optional + The KVStore or a string kvstore type: + 'local' : multi-devices on a single machine, will automatically + choose one from 'local_update_cpu', 'local_allreduce_cpu', and + 'local_allreduce_device' + 'dist_sync' : multi-machines with BSP + 'dist_async' : multi-machines with partical asynchronous - kvstore : kvstore, optional - An instance of kvstore. It overwrite both kvstore_type and update_on_kvstore - - logger : logging logger, optional + In default uses 'local', often no need to change for single machiine. """ model = FeedForward(symbol, ctx=ctx, num_round=num_round, optimizer=optimizer, initializer=initializer, **kwargs) model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, iter_end_callback=iter_end_callback, - update_on_kvstore=update_on_kvstore, - kvstore_type=kvstore_type, kvstore=kvstore, + kvstore=kvstore, logger=logger) return model diff --git a/src/c_api.cc b/src/c_api.cc index 026a293cdb02..e81b36097d4f 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -1088,3 +1088,10 @@ int MXKVStoreSendCommmandToServers(KVStoreHandle handle, cmd_id, std::string(cmd_body)); API_END(); } + +int MXKVStoreGetType(KVStoreHandle handle, + const char** type) { + API_BEGIN(); + *CHECK_NOTNULL(type) = static_cast(handle)->type().c_str(); + API_END(); +} diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index c11cb98461dc..edd78e617b75 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -17,20 +17,35 @@ namespace mxnet { KVStore* KVStore::Create(const char *type_name) { std::string tname = type_name; - if (tname == "local") { - return new kvstore::KVStoreLocal(); - } else if (tname == "device") { - return new kvstore::KVStoreDevice(); - } else if (tname == "dist") { + std::transform(tname.begin(), tname.end(), tname.begin(), ::tolower); + KVStore* kv = nullptr; + if (tname == "local" || + tname == "local_update_cpu" || + tname == "local_allreduce_cpu") { + kv = new kvstore::KVStoreLocal(); + } else if (tname == "device" || + tname == "local_allreduce_device") { + tname = "local_allreduce_device"; + kv = new kvstore::KVStoreDevice(); + } else if (tname == "dist_async") { #if MXNET_USE_DIST_KVSTORE - return new kvstore::KVStoreDist(); + kv = new kvstore::KVStoreDist(); #else - LOG(FATAL) << "compile with USE_DIST_KVSTORE=1"; + LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname; return nullptr; #endif // MXNET_USE_DIST_KVSTORE + } else if (tname == "dist_sync") { +#if MXNET_USE_DIST_KVSTORE + kv = new kvstore::KVStoreDist(); +#else + LOG(FATAL) << "compile with USE_DIST_KVSTORE=1 to use " << tname; + return nullptr; +#endif // MXNET_USE_DIST_KVSTORE + } else { + LOG(FATAL) << "Unknown KVStore type \"" << tname << "\""; } - LOG(FATAL) << "Unknown KVStore type \"" << type_name << "\""; - return nullptr; + kv->type_ = tname; + return kv; } } // namespace mxnet diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index ada94490ce86..b76b9dd536a3 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -104,8 +104,13 @@ def test_updater(dev = 'cpu'): for v in vv: check_diff_to_scalar(v, num_devs * num_push) +def test_get_type(): + kvtype = 'local_allreduce_cpu' + kv = mx.kv.create(kvtype) + assert kv.get_type() == kvtype if __name__ == '__main__': + test_get_type() test_single_kv_pair() test_list_kv_pair() test_aggregator() From d0224890f1edc8136cb96368afdc1b9a3af37e63 Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 23:36:09 -0400 Subject: [PATCH 15/19] [kvstore] lint --- python/mxnet/model.py | 9 ++------- src/io/iter_mnist.cc | 1 - tests/python/train/test_mlp.py | 6 +----- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index c97f550cec0c..ea926500dfba 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -193,10 +193,6 @@ def _train_multi_device(symbol, ctx, input_shape, Notes ----- - This function will inplace update the NDArrays in arg_parans and aux_states. - - Turning update_on_kvstore on and off can affect speed of multi-gpu training. - - It is auto selected by default. - - update_on_kvstore=True works well for inception type nets that contains many small weights. - - update_on_kvstore=False works better for Alexnet style net with bulk weights. """ if logger is None: logger = logging @@ -224,7 +220,7 @@ def _train_multi_device(symbol, ctx, input_shape, texec.copy_params_from(arg_params, aux_params) # create kvstore - if isinstance(kvstore, KVStore): + if isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): # create kvstore using the string type @@ -755,8 +751,7 @@ def load(prefix, iteration, ctx=None, **kwargs): def create(symbol, X, y=None, ctx=None, num_round=None, optimizer='sgd', initializer=Uniform(0.01), eval_data=None, eval_metric='acc', iter_end_callback=None, - update_on_kvstore=None, kvstore_type='local', kvstore=None, - logger=None, **kwargs): + kvstore='local', logger=None, **kwargs): """Functional style to create a model. This function will be more consistent with functional diff --git a/src/io/iter_mnist.cc b/src/io/iter_mnist.cc index 6ac3415237a1..cb2e2a853e0d 100644 --- a/src/io/iter_mnist.cc +++ b/src/io/iter_mnist.cc @@ -166,7 +166,6 @@ class MNISTIter: public IIterator { delete stdimg; } inline void LoadLabel(void) { - dmlc::SeekStream* stdlabel = dmlc::SeekStream::CreateForRead(param_.label.c_str()); ReadInt(stdlabel); diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 65f3d90d9e3d..85266e12df52 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -40,9 +40,6 @@ def accuracy(label, pred): def test_mlp(): # print logging by default logging.basicConfig(level=logging.DEBUG) - console = logging.StreamHandler() - console.setLevel(logging.DEBUG) - logging.getLogger('').addHandler(console) model = mx.model.FeedForward.create( softmax, @@ -53,8 +50,7 @@ def test_mlp(): ctx=[mx.cpu(i) for i in range(2)], num_round=num_round, learning_rate=0.1, wd=0.0004, - momentum=0.9, - update_on_kvstore=True) + momentum=0.9) logging.info('Finish traning...') prob = model.predict(val_dataiter) From 782cd274b10aa76706adff36a87c498b378e535c Mon Sep 17 00:00:00 2001 From: muli Date: Sat, 10 Oct 2015 23:48:16 -0400 Subject: [PATCH 16/19] [doc] for multipe machines --- doc/developer-guide/multi_node.md | 54 +++++++++++++++++-------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/doc/developer-guide/multi_node.md b/doc/developer-guide/multi_node.md index fbf892025b4d..07259417aaf7 100644 --- a/doc/developer-guide/multi_node.md +++ b/doc/developer-guide/multi_node.md @@ -1,11 +1,34 @@ # Multi-devices and multi-machines -## Architecture +## Introduction -A device could be a GPU card, CPU, or other computational units. +MXNet uses a two-level *parameter server* for data synchronization. +- On the first layer, data are synchronized over multiple devices within a + single worker machine. A device could be a GPU card, CPU, or other computational + units. We often use sequential consistency model, also known as BSP, on this + level. + +- On the second layer, data are synchronize over multiple workers via server + machines. We can either use a sequential consistency model for guaranteed + convergence or an (partial)-asynchronous model for better system performance. + +## KVStore + +MXNet implemented the two-level parameter server in class *KVStore*. We +currently provide the following types: + +| kvstore type | multi-devices | multi-workers | #ex per device | #ex per update | max delay | +| :--- | --- | --- | --- | --- | --- | +| `none` | no | no | *b* | *b* | *0* | +| `local` | yes | no | *b / k* | *b* | *0* | +| `dist_sync` | yes | yes | *b / k* | *b × n* | *0* | +| `dist_async` | yes | yes | *b / k* | *b* | inf | + +where + - **n** : the number of workers (often mean machines) - **k** : the number of devices used on a worker (could vary for different workers) - **b** : the batch size set by users @@ -20,15 +43,6 @@ A device could be a GPU card, CPU, or other computational units. larger delay often improves the performance, but may slows down the convergence. - -| kvstore type | multi-devices | multi-workers | #ex per device | #ex per update | max delay | -| :--- | --- | --- | --- | --- | --- | -| `none` | no | no | *b* | *b* | *0* | -| `local` | yes | no | *b / k* | *b* | *0* | -| `dist_sync` | yes | yes | *b / k* | *b × n* | *0* | -| `dist_async` | yes | yes | *b / k* | *b* | inf | - - ## Multiple devices on a single machine KV store `local` synchronizes data over multiple devices on a single machine. @@ -41,23 +55,14 @@ When using `local`, the system will automatically chooses one of the following three types. Their differences are on where to average the gradients over all devices, and where to update the weight. - -They produce -(almost) the same results, but may vary on speed. - -share the -same semantic - -They are semantically identical, but their speemay have different -speeds -We can further fine tune the system performance by specifying : - | kvstore type | average gradient | perform update | | :--- | :--- | --- | | `local_update_cpu` | CPU | CPU | | `local_allreduce_cpu` | CPU | all devices | | `local_allreduce_device` | a device | all devices | +They produce (almost) the same results, but may vary on speed. + - `local_update_cpu`, gradients are first copied to main memory, next averaged on CPU, and then update the weight on CPU. It is suitable when the average size of weights are not large and there are a large number of weight. For example the @@ -88,5 +93,6 @@ situation. But they are different on both semantic and performance. between all workers, and therefore may harm the system performance. - `dist_async`: the gradient is sent to the servers, and the weight is updated - there. The weights a worker has may be stale. - (TODO) make the max delay be settable? + there. The weights a worker has may be stale. This loose data consistency + model reduces the machine synchronization cost and therefore could improve the + system performance. But it may harm the convergence speed. From 52b7cf10422d47244ee72c83877a06ec1f30e5cc Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 11 Oct 2015 05:28:32 +0000 Subject: [PATCH 17/19] [kvstore] bug fix --- python/mxnet/context.py | 4 ++-- python/mxnet/model.py | 11 ++++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 1ed3dae5fb23..5d5a5f40ab21 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -29,8 +29,8 @@ class Context(object): """ # static class variable default_ctx = None - devtype2str = {1: 'cpu', 2: 'gpu'} - devstr2type = {'cpu': 1, 'gpu': 2} + devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu'} + devstr2type = {'cpu': 1, 'gpu': 2, 'cpu': 3} def __init__(self, device_type, device_id=0): if isinstance(device_type, Context): self.device_typeid = device_type.device_typeid diff --git a/python/mxnet/model.py b/python/mxnet/model.py index ea926500dfba..b365a0525116 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -241,9 +241,9 @@ def _train_multi_device(symbol, ctx, input_shape, raise TypeError('kvstore must be either KVStore or str') # detect whether or not update weight on kvstore - update_on_kvstore = False - if kv and 'local_allreduce' in kv.get_type(): - update_on_kvstore = True + update_on_kvstore = True + if not kv or 'local_allreduce' in kv.get_type(): + update_on_kvstore = False # init optimizer before give it to kv or get_updater optimizer.begin_round(begin_round) @@ -308,8 +308,9 @@ def _train_multi_device(symbol, ctx, input_shape, # pull back the sum gradients, to the same locations. kv.pull(index, grad_list, priority=-index) if not update_on_kvstore: - for w, g in zip(arg_list, grad_list): - updater(index, g, w) + for k, p in enumerate(zip(arg_list, grad_list)): + w, g = p + updater(index*num_device+k, g, w) nbatch += 1 # epoch callback (for print purpose) From 709ad2ba704e8f9c79bf42d40b42ed2d68c3676b Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 11 Oct 2015 05:32:37 +0000 Subject: [PATCH 18/19] [python] add cpu_pinned on context --- python/mxnet/context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/context.py b/python/mxnet/context.py index 5d5a5f40ab21..b35d910407a8 100644 --- a/python/mxnet/context.py +++ b/python/mxnet/context.py @@ -29,8 +29,8 @@ class Context(object): """ # static class variable default_ctx = None - devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu'} - devstr2type = {'cpu': 1, 'gpu': 2, 'cpu': 3} + devtype2str = {1: 'cpu', 2: 'gpu', 3: 'cpu_pinned'} + devstr2type = {'cpu': 1, 'gpu': 2, 'cpu_pinned': 3} def __init__(self, device_type, device_id=0): if isinstance(device_type, Context): self.device_typeid = device_type.device_typeid From 0ccab70944aff0b18a435a69f4edfcb9a78f8c46 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 11 Oct 2015 16:37:55 -0400 Subject: [PATCH 19/19] [kvstore] update --- doc/developer-guide/multi_node.md | 18 +++---- include/mxnet/kvstore.h | 2 +- python/mxnet/kvstore.py | 52 +++++++++---------- python/mxnet/kvstore_server.py | 2 +- python/mxnet/model.py | 72 ++++++++++++++++++--------- tests/python/unittest/test_kvstore.py | 2 +- 6 files changed, 85 insertions(+), 63 deletions(-) diff --git a/doc/developer-guide/multi_node.md b/doc/developer-guide/multi_node.md index 07259417aaf7..3f43636b41dd 100644 --- a/doc/developer-guide/multi_node.md +++ b/doc/developer-guide/multi_node.md @@ -18,20 +18,16 @@ MXNet uses a two-level *parameter server* for data synchronization. ## KVStore MXNet implemented the two-level parameter server in class *KVStore*. We -currently provide the following types: +currently provide the following three types. Given the batch size *b*: -| kvstore type | multi-devices | multi-workers | #ex per device | #ex per update | max delay | +| kvstore type | #devices | #workers | #ex per device | #ex per update | max delay | | :--- | --- | --- | --- | --- | --- | -| `none` | no | no | *b* | *b* | *0* | -| `local` | yes | no | *b / k* | *b* | *0* | -| `dist_sync` | yes | yes | *b / k* | *b × n* | *0* | -| `dist_async` | yes | yes | *b / k* | *b* | inf | +| `local` | *k* | 1 | *b / k* | *b* | *0* | +| `dist_sync` | *k* | *n* | *b / k* | *b × n* | *0* | +| `dist_async` | *k* | *n* | *b / k* | *b* | inf | -where - -- **n** : the number of workers (often mean machines) -- **k** : the number of devices used on a worker (could vary for different workers) -- **b** : the batch size set by users +where the number of devices *k* used on a worker could vary for different +workers. And - **number examples per update** : for each update, the number of examples used to calculate the averaged gradients. Often the larger, the slower the convergence. diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index e97fbf06035f..59d3c2390c7d 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -39,7 +39,7 @@ class KVStore { /** * \brief return the type */ - inline std::string type() { return type_; } + inline const std::string& type() { return type_; } /*! * \brief Initialize a list of key-value pair to the store. diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 98eda59a9b15..4e2ab83d0a8d 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -6,7 +6,7 @@ import pickle from .ndarray import NDArray from .base import _LIB -from .base import check_call, c_array, c_str, string_types, mx_uint +from .base import check_call, c_array, c_str, string_types, mx_uint, py_str from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt @@ -63,25 +63,12 @@ def __init__(self, handle): def __del__(self): check_call(_LIB.MXKVStoreFree(self.handle)) - def get_type(self): - """Get the type of this kvstore - - Returns - ------- - type : str - the string type - """ - kv_type = ctypes.c_char_p() - check_call(_LIB.MXKVStoreGetType(self.handle, ctypes.byref(kv_type))) - return kv_type.value - - def init(self, key, value): """ Initialize a single or a sequence of key-value pairs into the store. For each key, one must init it before push and pull. - Only worker 0's (get_rank() == 0) data are used. + Only worker 0's (rank == 0) data are used. This function returns after data have been initialized successfully @@ -108,7 +95,7 @@ def init(self, key, value): >>> keys = [5, 7, 9] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - if (self.get_rank() == 0): + if (self.rank == 0): ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStoreInit( self.handle, mx_uint(len(ckeys)), ckeys, cvals)) @@ -279,7 +266,21 @@ def set_optimizer(self, optimizer): else: self._set_updater(opt.get_updater(optimizer)) - def get_rank(self): + @property + def type(self): + """Get the type of this kvstore + + Returns + ------- + type : str + the string type + """ + kv_type = ctypes.c_char_p() + check_call(_LIB.MXKVStoreGetType(self.handle, ctypes.byref(kv_type))) + return py_str(kv_type.value) + + @property + def rank(self): """Get the rank of this worker node Returns @@ -291,7 +292,8 @@ def get_rank(self): check_call(_LIB.MXKVStoreGetRank(self.handle, ctypes.byref(rank))) return rank.value - def get_num_workers(self): + @property + def num_workers(self): """Get the number of worker ndoes Returns @@ -345,17 +347,17 @@ def _barrier(self): pulling, we can place a barrier to guarantee that the initialization is finished. - The following codes run on n machines in parallel - - >>> if kv.get_rank() == 0: - ... kv.init(keys, values); - ... kv.barrier() - ... kv.pull(keys, out = values); - But note that, this functions only blocks the main thread of workers until all of them are reached this point. It doesn't guarantee that all operations issued before are actually finished, such as \ref Push and \ref Pull. In that case, we need to call \ref Wait or \ref WaitAll + + The following codes implement a BSP model + + >>> kv.push(keys, values) + ... kv._wait(keys) + ... kv._barrier() + ... kv.pull(keys, out = values); """ check_call(_LIB.MXKVStoreBarrier(self.handle)) diff --git a/python/mxnet/kvstore_server.py b/python/mxnet/kvstore_server.py index 85f9c26ef5ae..dc4356925b40 100644 --- a/python/mxnet/kvstore_server.py +++ b/python/mxnet/kvstore_server.py @@ -31,7 +31,7 @@ def server_controller(cmd_id, cmd_body): self.kvstore.set_optimizer(optimizer) else: print ("server %d, unknown command (%d, %s)" % ( - self.kvstore.get_rank(), cmd_id, cmd_body)) + self.kvstore.rank, cmd_id, cmd_body)) return server_controller def run(self): diff --git a/python/mxnet/model.py b/python/mxnet/model.py index b365a0525116..a8430e55c6eb 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -123,6 +123,50 @@ def _split_input_slice(input_shape, num_split): shapes.append(tuple(s)) return (slices, shapes) +def _create_kvstore(kvstore, num_device, arg_params): + """Create kvstore + + This function select and create a proper kvstore if given the kvstore type + + Parameters + ---------- + + kvstore : KVStore or str + The kvstore + + num_device : int + The number of devices + + arg_params : dict of str to NDArray + Model parameter, dict of name to NDArray of net's weights. + """ + + if isinstance(kvstore, kvs.KVStore): + kv = kvstore + elif isinstance(kvstore, str): + # create kvstore using the string type + if num_device is 1 and 'dist' not in kvstore: + # no need to use kv for single device and single machine + kv = None + else: + if kvstore is 'local': + # automatically select a proper local + max_size = max(np.prod(param.shape) for param in arg_params.values()) + if max_size < 1024 * 1024 * 16: + kvstore = 'local_update_cpu' + else: + kvstore = 'local_allreduce_cpu' + logging.info('Auto-select kvstore type = %s', kvstore) + kv = kvs.create(kvstore) + else: + raise TypeError('kvstore must be either KVStore or str') + + # detect whether or not update weight on kvstore + update_on_kvstore = True + if not kv or 'local_allreduce' in kv.type: + update_on_kvstore = False + + return (kv, update_on_kvstore) def _train_multi_device(symbol, ctx, input_shape, arg_params, aux_params, @@ -220,30 +264,7 @@ def _train_multi_device(symbol, ctx, input_shape, texec.copy_params_from(arg_params, aux_params) # create kvstore - if isinstance(kvstore, kvs.KVStore): - kv = kvstore - elif isinstance(kvstore, str): - # create kvstore using the string type - if num_device is 1 and 'dist' not in kvstore: - # no need to use kv for single device and single machine - kv = None - else: - if kvstore is 'local': - # automatically select a proper local - max_size = max(np.prod(param.shape) for param in arg_params.values()) - if max_size < 1024 * 1024 * 16: - kvstore = 'local_update_cpu' - else: - kvstore = 'local_allreduce_cpu' - logging.info('Auto-select kvstore type = %s', kvstore) - kv = kvs.create(kvstore) - else: - raise TypeError('kvstore must be either KVStore or str') - - # detect whether or not update weight on kvstore - update_on_kvstore = True - if not kv or 'local_allreduce' in kv.get_type(): - update_on_kvstore = False + (kv, update_on_kvstore) = _create_kvstore(kvstore, num_device, arg_params) # init optimizer before give it to kv or get_updater optimizer.begin_round(begin_round) @@ -309,6 +330,9 @@ def _train_multi_device(symbol, ctx, input_shape, kv.pull(index, grad_list, priority=-index) if not update_on_kvstore: for k, p in enumerate(zip(arg_list, grad_list)): + # faked an index here, to make optimizer create diff + # state for the same index but on diff devs, TODO(mli) + # use a better solution latter w, g = p updater(index*num_device+k, g, w) diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index b76b9dd536a3..77439677320f 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -107,7 +107,7 @@ def test_updater(dev = 'cpu'): def test_get_type(): kvtype = 'local_allreduce_cpu' kv = mx.kv.create(kvtype) - assert kv.get_type() == kvtype + assert kv.type == kvtype if __name__ == '__main__': test_get_type()