From d73a486093c320232132e5c740329eb113ff21ab Mon Sep 17 00:00:00 2001 From: Mu Li Date: Sun, 13 Sep 2015 16:31:36 -0400 Subject: [PATCH 01/11] rename executor.heads() to executor.outputs in example --- example/mnist/mlp_multi_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index 9ada84668727..9c0b59298879 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -54,7 +54,7 @@ def updater(key, grad, weight): # create executors for devices executors = [mlp.bind(devs[d], params[d], grads[d]) for d in range(num_devs)] -forward_out = [mx.nd.zeros(e.heads()[0].shape) for e in executors] +forward_out = [mx.nd.zeros(e.outputs[0].shape) for e in executors] # data reader get_data.GetMNIST_ubyte() @@ -97,7 +97,7 @@ def run_sgd(): params[d][param_names.index('mlp_label')][:] = label[rows] executors[d].forward() - executors[d].heads()[0].copyto(forward_out[d]) + executors[d].outputs[0].copyto(forward_out[d]) executors[d].backward([forward_out[d]]) # push gradient @@ -123,7 +123,7 @@ def run_sgd(): # eval for d in range(num_devs): - val_acc += cal_acc(executors[d].heads()[0].asnumpy(), + val_acc += cal_acc(executors[d].outputs[0].asnumpy(), label[batch_splits[d]]) val_count += 1 From d963d7631cba93579d0a15bcf2ba2e8066058f81 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Sun, 13 Sep 2015 16:46:59 -0400 Subject: [PATCH 02/11] disable printing all ops in graphc_executor --- src/symbol/graph_executor.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index af752776df48..a97b34510b2e 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -42,9 +42,9 @@ class GraphExecutor : public Executor { this->InitDataEntryMemory(); this->InitOpNodes(); // TODO(bing): remove me when things are OK - LOG(INFO) << "-----Execution memory plan-----\n" - << DebugStr() << '\n' - << "------------------------------\n"; + // LOG(INFO) << "-----Execution memory plan-----\n" + // << DebugStr() << '\n' + // << "------------------------------\n"; } protected: From 61432b0e231d8c433dcc8875c45b01ff930a9bfc Mon Sep 17 00:00:00 2001 From: Mu Li Date: Sun, 13 Sep 2015 16:47:41 -0400 Subject: [PATCH 03/11] measure time in mlp_multi_gpu.py --- example/mnist/mlp_multi_gpu.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index 9c0b59298879..8260cdb21f5d 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -5,6 +5,7 @@ import sys sys.path.append("../../tests/python") import get_data +import time # use multiple devices num_devs = 4 @@ -12,7 +13,6 @@ mx.kvstore.start() # symbol net -batch_size = 100 data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128) act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") @@ -77,6 +77,7 @@ def run_sgd(): num_epochs = 9 for epoch in range(num_epochs): + start = time.time() print "Epoch %d" % epoch train_count = 0.0 train_acc = 0.0 @@ -127,8 +128,10 @@ def run_sgd(): label[batch_splits[d]]) val_count += 1 - print("Train Acc: ", train_acc / train_count) - print("Valid Acc: ", val_acc / val_count) + print("Train Acc: %g, Valid Acc: %g, time: %g" % ( + train_acc / train_count, + val_acc / val_count, + time.time() - start)) train_dataiter.reset() val_dataiter.reset() From 
5b57755cb21f08131b6bbdb6ff636d4b0234c3f1 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Sun, 13 Sep 2015 19:56:17 -0400 Subject: [PATCH 04/11] start mult-gpu cifa --- example/cifar10/cifa10_multi_gpus.py | 189 +++++++++++++++++++++++++++ python/mxnet/__init__.py | 1 + python/mxnet/updater.py | 17 +++ 3 files changed, 207 insertions(+) create mode 100644 example/cifar10/cifa10_multi_gpus.py create mode 100644 python/mxnet/updater.py diff --git a/example/cifar10/cifa10_multi_gpus.py b/example/cifar10/cifa10_multi_gpus.py new file mode 100644 index 000000000000..2f643b6e285e --- /dev/null +++ b/example/cifar10/cifa10_multi_gpus.py @@ -0,0 +1,189 @@ +# pylint: skip-file +import numpy as np +import mxnet as mx +import copy +import sys +sys.path.append("../../tests/python") +import get_data +import time + +# use multiple devices +num_devs = 4 +devs = [mx.gpu(i) for i in range(num_devs)] +mx.kvstore.start() + +# define the network +conv_cnt = 1 +concat_cnt = 1 +pool_cnt = 1 + +def ConvFactory(**kwargs): + global conv_cnt + param = copy.copy(kwargs) + act = param["act_type"] + del param["act_type"] + param["workspace"] = 256 + param["name"] = "conv%d" % conv_cnt + conv = mx.symbol.Convolution(**param) + bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt) + relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act) + conv_cnt += 1 + return relu + +def DownsampleFactory(data, ch_3x3, stride = 2): + global pool_cnt + global concat_cnt + param = {} + # conv 3x3 + param["kernel"] = (3, 3) + param["stride"] = (stride, stride) + param["num_filter"] = ch_3x3 + param["act_type"] = "relu" + param["data"] = data + param["pad"] = (1, 1) + conv3x3 = ConvFactory(**param) + # pool + del param["num_filter"] + del param["act_type"] + del param["pad"] + param["pool_type"] = "max" + param["name"] = "pool%d" % pool_cnt + pool = mx.symbol.Pooling(**param) + pool_cnt += 1 + # concat + concat = mx.symbol.Concat(*[conv3x3, pool], name="concat%d" % concat_cnt) + concat_cnt += 1 + return concat + +def SimpleFactory(data, ch_1x1, ch_3x3): + global concat_cnt + param = {} + # 1x1 + param["kernel"] = (1, 1) + param["num_filter"] = ch_1x1 + param["pad"] = (0, 0) + param["stride"] = (1, 1) + param["act_type"] = "relu" + param["data"] = data + conv1x1 = ConvFactory(**param) + + # 3x3 + param["kernel"] = (3, 3) + param["num_filter"] = ch_3x3 + param["pad"] = (1, 1) + conv3x3 = ConvFactory(**param) + + #concat + concat = mx.symbol.Concat(*[conv1x1, conv3x3], name="concat%d" % concat_cnt) + concat_cnt += 1 + return concat + +def RandomInit(narray): + in_num = narray.shape[1] + out_num = narray.shape[0] + a = np.sqrt(3.0 / (in_num + out_num)) + tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape)) + narray[:] = tmp + +data = mx.symbol.Variable(name="data") +conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu") +in3a = SimpleFactory(conv1, 32, 32) +in3b = SimpleFactory(in3a, 32, 48) +in3c = DownsampleFactory(in3b, 80) +in4a = SimpleFactory(in3c, 112, 48) +in4b = SimpleFactory(in4a, 96, 64) +in4c = SimpleFactory(in4b, 80, 80) +in4d = SimpleFactory(in4c, 48, 96) +in4e = DownsampleFactory(in4d, 96) +in5a = SimpleFactory(in4e, 176, 160) +in5b = SimpleFactory(in5a, 176, 160) +pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt) +flatten = mx.symbol.Flatten(data=pool, name="flatten1") +fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") +loss = mx.symbol.Softmax(data=fc, name="sm") + +# define model updater 
+updater = mx.updater.momentum( + learning_rate = .05, weight_decay = .0001, momentum = 0.9) +mx.kvstore.set_updater(updater) + + +#check data +get_data.GetCifar10() + +train_dataiter = mx.io.ImageRecordIter( + path_imgrec="data/cifar/train.rec", + mean_img="data/cifar/cifar_mean.bin", + rand_crop=True, + rand_mirror=True, + input_shape=(3,28,28), + batch_size=batch_size, + nthread=1) +test_dataiter = mx.io.ImageRecordIter( + path_imgrec="data/cifar/test.rec", + mean_img="data/cifar/cifar_mean.bin", + rand_crop=False, + rand_mirror=False, + input_shape=(3,28,28), + batch_size=batch_size, + nthread=1) + + +def progress(count, total, epoch, toc): + bar_len = 50 + filled_len = int(round(bar_len * count / float(total))) + + percents = round(100.0 * count / float(total), 1) + bar = '=' * filled_len + '-' * (bar_len - filled_len) + tic = time.time() + speed = batch_size / float(tic - toc) + suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) + sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) + + +def train(): + acc_train = 0. + acc_val = 0. + print("Start training...") + for i in range(epoch): + # train + train_acc = 0.0 + val_acc = 0.0 + train_nbatch = 0 + val_nbatch = 0 + all_train_bacth = round(50000 / float(batch_size) + 1) + for data, label in train_dataiter: + toc = time.time() + label = label.asnumpy().flatten() + tmp_label[:] = label + inputs["data"][:] = data + inputs["sm_label"][:] = tmp_label + executor.forward() + pred[:] = out_narray + train_acc += CalAcc(pred.asnumpy(), label) + train_nbatch += 1 + #executor.backward([out_narray]) + executor.backward() + + for grad, weight, mom in block: + Update(grad, weight, mom) + progress(train_nbatch, all_train_bacth, i, toc) + + # evaluate + for data, label in test_dataiter: + label = label.asnumpy().flatten() + inputs["data"][:] = data + executor.forward() + pred[:] = out_narray + val_acc += CalAcc(pred.asnumpy(), label) + val_nbatch += 1 + acc_train = train_acc / train_nbatch + acc_val = val_acc / val_nbatch + sys.stdout.write('\n') + print("Train Acc: ", train_acc / train_nbatch) + print("Valid Acc: ", val_acc / val_nbatch) + train_dataiter.reset() + test_dataiter.reset() + +if __name__ == "__main__": + train() diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index c591fc29510b..1863fdedc6e1 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -13,6 +13,7 @@ from . import ndarray from . import symbol from . import kvstore +from . import updater from . import io # use mx.nd as short for mx.ndarray from . 
import ndarray as nd diff --git a/python/mxnet/updater.py b/python/mxnet/updater.py new file mode 100644 index 000000000000..47337cecaa9b --- /dev/null +++ b/python/mxnet/updater.py @@ -0,0 +1,17 @@ +# coding: utf-8 + + +def momentum(learning_rate = .01, weight_decay = 0.0001, momentum=0.9): + """Stochastic Gradient Descent (SGD) updates with momentum + + Parameters + ---------- + """ + + def momentum_update(key, grad, weight): + mom = momentums[key] + mom *= momentum + mom += - learning_rate * (grad + weight_decay * weight) + weight += mom + + return momentum_update From ea98a4c8a28d2949fb166fac65dee34a9a3716c9 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 13 Sep 2015 21:22:36 -0400 Subject: [PATCH 05/11] update cifa mulgpu --- example/cifar10/cifa10_multi_gpus.py | 119 ++++++++++++++++++++------- 1 file changed, 87 insertions(+), 32 deletions(-) diff --git a/example/cifar10/cifa10_multi_gpus.py b/example/cifar10/cifa10_multi_gpus.py index 2f643b6e285e..7d7cd15b58a1 100644 --- a/example/cifar10/cifa10_multi_gpus.py +++ b/example/cifar10/cifa10_multi_gpus.py @@ -3,14 +3,14 @@ import mxnet as mx import copy import sys -sys.path.append("../../tests/python") +sys.path.append("../../tests/python/common") import get_data import time # use multiple devices num_devs = 4 -devs = [mx.gpu(i) for i in range(num_devs)] -mx.kvstore.start() +devs = [mx.cpu(i) for i in range(num_devs)] +# mx.kvstore.start() # define the network conv_cnt = 1 @@ -105,12 +105,42 @@ def RandomInit(narray): # define model updater updater = mx.updater.momentum( learning_rate = .05, weight_decay = .0001, momentum = 0.9) -mx.kvstore.set_updater(updater) - - -#check data +# mx.kvstore.set_updater(updater) + +# infer shape +batch_size = 128 +batch_size -= (batch_size % num_devs) +data_shape = (batch_size, 3, 28, 28) + +# create executors for devices +executors = [loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs] + +# find the params needed to be synchronized between devices +param_names = loss.list_arguments() +sync_indices = [index for index, name in enumerate(param_names) + if "weight" in name or "bias" in name] +sync_weights = [[e.list_arguments()[0][i] for e in executors] for i in sync_indices] +sync_grads = [[e.list_arguments()[1][i] for e in executors] for i in sync_indices] + + +# init global shared model +for idx in sync_indices: + shape = sync_weights[0][idx].shape + val = mx.nd.zeros(shape) + if "weight" in param_names[idx]: + val[:] = np.random.uniform(-0.1, 0.1, shape) + mx.kvstore.init(idx, val) + +# init local variables +for e in executors: + for idx, data in enumerate(e.list_arguments()[0]) + if "gamma" in param_names[idx]: + data = 1.0 + if "beta" in param_names[idx]: + data = 0.0 + +# data reader get_data.GetCifar10() - train_dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/train.rec", mean_img="data/cifar/cifar_mean.bin", @@ -119,7 +149,7 @@ def RandomInit(narray): input_shape=(3,28,28), batch_size=batch_size, nthread=1) -test_dataiter = mx.io.ImageRecordIter( +val_dataiter = mx.io.ImageRecordIter( path_imgrec="data/cifar/test.rec", mean_img="data/cifar/cifar_mean.bin", rand_crop=False, @@ -140,11 +170,19 @@ def progress(count, total, epoch, toc): suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) +def cal_acc(out, label): + pred = np.argmax(out, axis=1) + return np.sum(pred == label) * 1.0 / out.shape[0] def train(): acc_train = 0. acc_val = 0. 
+ k = batch_size / num_devs + batch_splits = [range(d*k, (d+1)*k) for d in range(num_devs)] print("Start training...") + data_in = [e.list_arguments()[0][param_names.index('data')] for e in executors] + label_in = [e.list_arguments()[0][param_names.index('loss_label')] for e in executors] + for i in range(epoch): # train train_acc = 0.0 @@ -152,36 +190,53 @@ def train(): train_nbatch = 0 val_nbatch = 0 all_train_bacth = round(50000 / float(batch_size) + 1) + for data, label in train_dataiter: - toc = time.time() + # pull weight + mx.kvstore.pull(sync_indices, out = sync_weights) + + # forward and backword + data = data.asnumpy() label = label.asnumpy().flatten() - tmp_label[:] = label - inputs["data"][:] = data - inputs["sm_label"][:] = tmp_label - executor.forward() - pred[:] = out_narray - train_acc += CalAcc(pred.asnumpy(), label) - train_nbatch += 1 - #executor.backward([out_narray]) - executor.backward() - - for grad, weight, mom in block: - Update(grad, weight, mom) - progress(train_nbatch, all_train_bacth, i, toc) + for d in range(num_devs): + rows = batch_splits[d] + data_in[d] = data[rows, :] + label_in[d] = label[rows] + executors[d].forward() + executors[d].backward() + + # push gradient + mx.kvstore.push(sync_indices, sync_grads) + + # evaluate + for d in range(num_devs): + train_acc += cal_acc(executors[d].outputs[0].asnumpy(), + label[batch_splits[d]]) + train_count += 1 + + progress(train_count, all_train_bacth, i, toc) # evaluate for data, label in test_dataiter: + # forward + data = data.asnumpy() label = label.asnumpy().flatten() - inputs["data"][:] = data - executor.forward() - pred[:] = out_narray - val_acc += CalAcc(pred.asnumpy(), label) - val_nbatch += 1 - acc_train = train_acc / train_nbatch - acc_val = val_acc / val_nbatch + for d in range(num_devs): + rows = batch_splits[d] + data_in[d] = data[rows,:] + executors[d].forward() + + # eval + for d in range(num_devs): + val_acc += cal_acc(executors[d].outputs[0].asnumpy(), + label[batch_splits[d]]) + val_count += 1 + sys.stdout.write('\n') - print("Train Acc: ", train_acc / train_nbatch) - print("Valid Acc: ", val_acc / val_nbatch) + + print("Train Acc: %g, Valid Acc: %g" % ( + train_acc / train_count, + val_acc / val_count)) train_dataiter.reset() test_dataiter.reset() From 6482c5cc4aa02c9937aa225da07a5742d0301e1e Mon Sep 17 00:00:00 2001 From: muli Date: Mon, 14 Sep 2015 12:50:25 -0400 Subject: [PATCH 06/11] update cifar10 mulgpu --- example/cifar10/cifa10_multi_gpus.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/example/cifar10/cifa10_multi_gpus.py b/example/cifar10/cifa10_multi_gpus.py index a84967e0c710..6018dcf82776 100644 --- a/example/cifar10/cifa10_multi_gpus.py +++ b/example/cifar10/cifa10_multi_gpus.py @@ -100,7 +100,7 @@ def RandomInit(narray): pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt) flatten = mx.symbol.Flatten(data=pool, name="flatten1") fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1") -loss = mx.symbol.Softmax(data=fc, name="sm") +loss = mx.symbol.Softmax(data=fc, name="loss") # define model updater updater = mx.updater.momentum( @@ -113,13 +113,12 @@ def RandomInit(narray): data_shape = (batch_size, 3, 28, 28) # create executors for devices + d = devs[0] loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) -executors = [] -for d in devs: - executors.append(loss.simple_bind(d, data = mx.nd.empty(data_shape, d))) +executors = 
[loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs] # d = devs[0] # ex = loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) @@ -136,8 +135,9 @@ def RandomInit(narray): # init global shared model +weights = executors[0].list_arguments()[0] for idx in sync_indices: - shape = sync_weights[0][idx].shape + shape = weights[idx].shape val = mx.nd.zeros(shape) if "weight" in param_names[idx]: val[:] = np.random.uniform(-0.1, 0.1, shape) @@ -187,6 +187,7 @@ def cal_acc(out, label): return np.sum(pred == label) * 1.0 / out.shape[0] def train(): + epoch = 10 acc_train = 0. acc_val = 0. k = batch_size / num_devs From 89fd652a5c6ad91fac43818c6582208594cfbc48 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 14 Sep 2015 14:35:20 -0400 Subject: [PATCH 07/11] simplify mlp mulgpu --- example/mnist/mlp_multi_gpu.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index 8260cdb21f5d..69b9426edbd3 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -54,7 +54,6 @@ def updater(key, grad, weight): # create executors for devices executors = [mlp.bind(devs[d], params[d], grads[d]) for d in range(num_devs)] -forward_out = [mx.nd.zeros(e.outputs[0].shape) for e in executors] # data reader get_data.GetMNIST_ubyte() @@ -98,16 +97,14 @@ def run_sgd(): params[d][param_names.index('mlp_label')][:] = label[rows] executors[d].forward() - executors[d].outputs[0].copyto(forward_out[d]) - executors[d].backward([forward_out[d]]) - + executors[d].backward() # push gradient for idx in sync_indices: mx.kvstore.push(idx, [g[idx] for g in grads]) # eval for d in range(num_devs): - train_acc += cal_acc(forward_out[d].asnumpy(), + train_acc += cal_acc(executors[d].outputs[0].asnumpy(), label[batch_splits[d]]) train_count += 1 From b696009efc169e961b194187517ae67f36357980 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 14 Sep 2015 15:34:15 -0400 Subject: [PATCH 08/11] bug fix in cifar mulgpu --- example/cifar10/cifa10_multi_gpus.py | 116 ++++++++++++--------------- example/mnist/mlp_multi_gpu.py | 1 - python/mxnet/kvstore.py | 14 +++- python/mxnet/updater.py | 8 +- 4 files changed, 68 insertions(+), 71 deletions(-) diff --git a/example/cifar10/cifa10_multi_gpus.py b/example/cifar10/cifa10_multi_gpus.py index 6018dcf82776..47e6c47f9073 100644 --- a/example/cifar10/cifa10_multi_gpus.py +++ b/example/cifar10/cifa10_multi_gpus.py @@ -8,8 +8,8 @@ import time # use multiple devices -num_devs = 2 -devs = [mx.cpu(i) for i in range(num_devs)] +num_devs = 4 +devs = [mx.gpu(i) for i in range(num_devs)] mx.kvstore.start() # define the network @@ -78,13 +78,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3): concat_cnt += 1 return concat -def RandomInit(narray): - in_num = narray.shape[1] - out_num = narray.shape[0] - a = np.sqrt(3.0 / (in_num + out_num)) - tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape)) - narray[:] = tmp - data = mx.symbol.Variable(name="data") conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu") in3a = SimpleFactory(conv1, 32, 32) @@ -105,80 +98,67 @@ def RandomInit(narray): # define model updater updater = mx.updater.momentum( learning_rate = .05, weight_decay = .0001, momentum = 0.9) -# mx.kvstore.set_updater(updater) +mx.kvstore.set_updater(updater) # infer shape -batch_size = 128 +batch_size = 196 batch_size -= (batch_size % num_devs) -data_shape = (batch_size, 3, 28, 28) +data_shape = (batch_size / num_devs, 3, 28, 28) # create executors for 
devices - -d = devs[0] -loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) -loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) - executors = [loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs] -# d = devs[0] -# ex = loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) -# print ex - -# executors = [loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs] - # find the params needed to be synchronized between devices param_names = loss.list_arguments() +sync_prefix = ["weight", "bias", "beta", "gamma"] sync_indices = [index for index, name in enumerate(param_names) - if "weight" in name or "bias" in name] + if any(prefix in name for prefix in sync_prefix)] + sync_weights = [[e.list_arguments()[0][i] for e in executors] for i in sync_indices] sync_grads = [[e.list_arguments()[1][i] for e in executors] for i in sync_indices] -# init global shared model +# init model weights = executors[0].list_arguments()[0] for idx in sync_indices: shape = weights[idx].shape val = mx.nd.zeros(shape) if "weight" in param_names[idx]: val[:] = np.random.uniform(-0.1, 0.1, shape) + elif "gamma" in param_names[idx]: + val[:] = 1.0 mx.kvstore.init(idx, val) -# init local variables -for e in executors: - for idx, data in enumerate(e.list_arguments()[0]): - if "gamma" in param_names[idx]: - data = 1.0 - if "beta" in param_names[idx]: - data = 0.0 - # data reader get_data.GetCifar10() + train_dataiter = mx.io.ImageRecordIter( - path_imgrec="data/cifar/train.rec", - mean_img="data/cifar/cifar_mean.bin", - rand_crop=True, - rand_mirror=True, - input_shape=(3,28,28), - batch_size=batch_size, - nthread=1) + path_imgrec="data/cifar/train.rec", + mean_img="data/cifar/cifar_mean.bin", + rand_crop=True, + rand_mirror=True, + shuffle=True, + input_shape=(3,28,28), + batch_size=batch_size, + nthread=1) + val_dataiter = mx.io.ImageRecordIter( - path_imgrec="data/cifar/test.rec", - mean_img="data/cifar/cifar_mean.bin", - rand_crop=False, - rand_mirror=False, - input_shape=(3,28,28), - batch_size=batch_size, - nthread=1) + path_imgrec="data/cifar/test.rec", + mean_img="data/cifar/cifar_mean.bin", + rand_crop=False, + rand_mirror=False, + input_shape=(3,28,28), + batch_size=batch_size, + nthread=1) -def progress(count, total, epoch, toc): +def progress(count, total, epoch, tic): bar_len = 50 filled_len = int(round(bar_len * count / float(total))) - percents = round(100.0 * count / float(total), 1) bar = '=' * filled_len + '-' * (bar_len - filled_len) - tic = time.time() - speed = batch_size / float(tic - toc) + toc = time.time() + speed = batch_size / float(toc - tic) suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed) sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix)) @@ -187,7 +167,7 @@ def cal_acc(out, label): return np.sum(pred == label) * 1.0 / out.shape[0] def train(): - epoch = 10 + epoch = 1 acc_train = 0. acc_val = 0. 
k = batch_size / num_devs @@ -198,13 +178,15 @@ def train(): for i in range(epoch): # train + start = time.time() train_acc = 0.0 val_acc = 0.0 - train_nbatch = 0 - val_nbatch = 0 - all_train_bacth = round(50000 / float(batch_size) + 1) + train_count = 0 + val_count = 0 + all_train_bacth = round(50000 / float(batch_size/num_devs) + 1) for data, label in train_dataiter: + tic = time.time() # pull weight mx.kvstore.pull(sync_indices, out = sync_weights) @@ -213,11 +195,16 @@ def train(): label = label.asnumpy().flatten() for d in range(num_devs): rows = batch_splits[d] - data_in[d] = data[rows, :] - label_in[d] = label[rows] + data_in[d][:] = data[rows, :] + label_in[d][:] = label[rows] executors[d].forward() executors[d].backward() + # normalize gradient + for grads in sync_grads: + for g in grads: + g /= batch_size + # push gradient mx.kvstore.push(sync_indices, sync_grads) @@ -227,16 +214,16 @@ def train(): label[batch_splits[d]]) train_count += 1 - progress(train_count, all_train_bacth, i, toc) + progress(train_count, all_train_bacth, i, tic) # evaluate - for data, label in test_dataiter: + for data, label in val_dataiter: # forward data = data.asnumpy() label = label.asnumpy().flatten() for d in range(num_devs): rows = batch_splits[d] - data_in[d] = data[rows,:] + data_in[d][:] = data[rows,:] executors[d].forward() # eval @@ -247,12 +234,13 @@ def train(): sys.stdout.write('\n') - print("Train Acc: %g, Valid Acc: %g" % ( + print("Train Acc: %g, Valid Acc: %g, Time: %g sec" % ( train_acc / train_count, - val_acc / val_count)) + val_acc / val_count, + time.time() - start)) + train_dataiter.reset() - test_dataiter.reset() + val_dataiter.reset() - mx.kvstore.stop() if __name__ == "__main__": train() diff --git a/example/mnist/mlp_multi_gpu.py b/example/mnist/mlp_multi_gpu.py index 69b9426edbd3..222696d92cc1 100644 --- a/example/mnist/mlp_multi_gpu.py +++ b/example/mnist/mlp_multi_gpu.py @@ -134,4 +134,3 @@ def run_sgd(): if __name__ == "__main__": run_sgd() - mx.kvstore.stop() diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 1132ac89e62e..426962751039 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -6,6 +6,7 @@ from .ndarray import NDArray from .base import _LIB from .base import check_call, c_array, NDArrayHandle +import atexit def _ctype_key_value(keys, vals): """parse key-value args into ctype""" @@ -34,10 +35,6 @@ def start(): """start kvstore""" check_call(_LIB.MXKVStoreStart()) -def stop(): - """ Stop kvstore """ - check_call(_LIB.MXKVStoreStop()) - def init(key, value): """ Initialize a list of key-value pairs @@ -110,3 +107,12 @@ def updater(recv, local): global _updater_func _updater_func = _updater_proto(_updater_wrapper(updater)) check_call(_LIB.MXKVStoreSetUpdater(_updater_func)) + +def stop(): + """ Stop kvstore """ + check_call(_LIB.MXKVStoreStop()) + # need to clear _updater_func before _LIB + global _updater_func + _updater_func = None + +atexit.register(stop) diff --git a/python/mxnet/updater.py b/python/mxnet/updater.py index 47337cecaa9b..57702e01c730 100644 --- a/python/mxnet/updater.py +++ b/python/mxnet/updater.py @@ -1,5 +1,6 @@ # coding: utf-8 - +from __future__ import absolute_import +from .ndarray import zeros def momentum(learning_rate = .01, weight_decay = 0.0001, momentum=0.9): """Stochastic Gradient Descent (SGD) updates with momentum @@ -7,8 +8,11 @@ def momentum(learning_rate = .01, weight_decay = 0.0001, momentum=0.9): Parameters ---------- """ - + momentums = {} def momentum_update(key, grad, weight): + # weight += 
- learning_rate * (grad + weight_decay * weight) + if not momentums.has_key(key): + momentums[key] = zeros(grad.shape) mom = momentums[key] mom *= momentum mom += - learning_rate * (grad + weight_decay * weight) From 63e5775f3740e0e6b48379a8575ef003e26a6667 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 14 Sep 2015 15:40:49 -0400 Subject: [PATCH 09/11] tiny changes --- .../cifar10/{cifa10_multi_gpus.py => cifar10_multi_gpus.py} | 2 +- include/mxnet/kvstore.h | 4 +++- src/kvstore/kvstore.cc | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) rename example/cifar10/{cifa10_multi_gpus.py => cifar10_multi_gpus.py} (99%) diff --git a/example/cifar10/cifa10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py similarity index 99% rename from example/cifar10/cifa10_multi_gpus.py rename to example/cifar10/cifar10_multi_gpus.py index 47e6c47f9073..dd4a57dcd3ee 100644 --- a/example/cifar10/cifa10_multi_gpus.py +++ b/example/cifar10/cifar10_multi_gpus.py @@ -167,7 +167,7 @@ def cal_acc(out, label): return np.sum(pred == label) * 1.0 / out.shape[0] def train(): - epoch = 1 + epoch = 7 acc_train = 0. acc_val = 0. k = batch_size / num_devs diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 0acb05f0fa9d..5f6b07680b92 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -40,7 +40,9 @@ class KVStore { * * clear all key-value pairs stored, updater, and devices binded */ - virtual void Stop() { get_impl()->Stop(); delete impl_; impl_ = NULL; } + virtual void Stop() { + if (impl_) { impl_->Stop(); delete impl_; impl_ = NULL; } + } /** * \brief Initialize a list of key-value pair to the store. diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index 55fd286f7d8e..bcf2ca2b7741 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -11,7 +11,7 @@ namespace mxnet { void KVStore::Start() { - CHECK(impl_ == NULL) << "double initialization, call Stop() first"; + if (impl_ != NULL) Stop(); char* num_worker = getenv("DMLC_NUM_WORKER"); if (num_worker == NULL || atoi(num_worker) == 1) { impl_ = new KVStoreLocal(); From 6c4dc5786b39376d2196b3c68ff471263a031814 Mon Sep 17 00:00:00 2001 From: muli Date: Mon, 14 Sep 2015 15:48:19 -0400 Subject: [PATCH 10/11] fix lint --- example/cifar10/cifar10_multi_gpus.py | 21 ++++++++++++++++++--- python/mxnet/__init__.py | 1 - python/mxnet/updater.py | 21 --------------------- 3 files changed, 18 insertions(+), 25 deletions(-) delete mode 100644 python/mxnet/updater.py diff --git a/example/cifar10/cifar10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py index dd4a57dcd3ee..174d00e7fe93 100644 --- a/example/cifar10/cifar10_multi_gpus.py +++ b/example/cifar10/cifar10_multi_gpus.py @@ -8,8 +8,8 @@ import time # use multiple devices -num_devs = 4 -devs = [mx.gpu(i) for i in range(num_devs)] +num_devs = 1 +devs = [mx.cpu(i) for i in range(num_devs)] mx.kvstore.start() # define the network @@ -96,7 +96,22 @@ def SimpleFactory(data, ch_1x1, ch_3x3): loss = mx.symbol.Softmax(data=fc, name="loss") # define model updater -updater = mx.updater.momentum( + +def momentum(learning_rate=.01, weight_decay=0.0001, momentum=0.9): + """Stochastic Gradient Descent (SGD) updates with momentum + """ + momentums = {} + def momentum_update(key, grad, weight): + # weight += - learning_rate * (grad + weight_decay * weight) + if not momentums.has_key(key): + momentums[key] = mx.nd.zeros(grad.shape) + mom = momentums[key] + mom *= momentum + mom += - learning_rate * (grad + weight_decay * weight) + weight += mom + return momentum_update + 
+updater = momentum( learning_rate = .05, weight_decay = .0001, momentum = 0.9) mx.kvstore.set_updater(updater) diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 1863fdedc6e1..c591fc29510b 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -13,7 +13,6 @@ from . import ndarray from . import symbol from . import kvstore -from . import updater from . import io # use mx.nd as short for mx.ndarray from . import ndarray as nd diff --git a/python/mxnet/updater.py b/python/mxnet/updater.py deleted file mode 100644 index 57702e01c730..000000000000 --- a/python/mxnet/updater.py +++ /dev/null @@ -1,21 +0,0 @@ -# coding: utf-8 -from __future__ import absolute_import -from .ndarray import zeros - -def momentum(learning_rate = .01, weight_decay = 0.0001, momentum=0.9): - """Stochastic Gradient Descent (SGD) updates with momentum - - Parameters - ---------- - """ - momentums = {} - def momentum_update(key, grad, weight): - # weight += - learning_rate * (grad + weight_decay * weight) - if not momentums.has_key(key): - momentums[key] = zeros(grad.shape) - mom = momentums[key] - mom *= momentum - mom += - learning_rate * (grad + weight_decay * weight) - weight += mom - - return momentum_update From 4ab44b180426d521322588989ad1778744b285ac Mon Sep 17 00:00:00 2001 From: Mu Li Date: Mon, 14 Sep 2015 15:49:03 -0400 Subject: [PATCH 11/11] tiny fix --- example/cifar10/cifar10_multi_gpus.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/cifar10/cifar10_multi_gpus.py b/example/cifar10/cifar10_multi_gpus.py index 174d00e7fe93..a88d8b508f7e 100644 --- a/example/cifar10/cifar10_multi_gpus.py +++ b/example/cifar10/cifar10_multi_gpus.py @@ -8,8 +8,8 @@ import time # use multiple devices -num_devs = 1 -devs = [mx.cpu(i) for i in range(num_devs)] +num_devs = 4 +devs = [mx.gpu(i) for i in range(num_devs)] mx.kvstore.start() # define the network
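
Taken together, PATCH 04 through PATCH 11 turn the single-device CIFAR-10 example into data-parallel training: model weights live in the kvstore, every device runs forward/backward on its slice of each batch, gradients are normalized by the global batch size and pushed, and a momentum updater registered with the kvstore applies them. What follows is a condensed sketch of that pattern rather than code lifted verbatim from any one patch; it assumes the 2015-era mxnet Python API used in these diffs (mx.kvstore.start/init/pull/push/set_updater, Symbol.simple_bind, Executor.list_arguments() returning the argument and gradient arrays) and Python 2 semantics, and it leaves the `loss` symbol and `train_dataiter` to be defined exactly as in example/cifar10/cifar10_multi_gpus.py.

import numpy as np
import mxnet as mx

def momentum(learning_rate=.05, weight_decay=.0001, momentum=0.9):
    """SGD-with-momentum updater; the kvstore calls it on aggregated gradients."""
    momentums = {}
    def momentum_update(key, grad, weight):
        if key not in momentums:
            momentums[key] = mx.nd.zeros(grad.shape)
        mom = momentums[key]
        mom *= momentum
        mom += - learning_rate * (grad + weight_decay * weight)
        weight += mom
    return momentum_update

mx.kvstore.start()
mx.kvstore.set_updater(momentum())

# one executor per device; each one sees batch_size / num_devs examples
num_devs = 4
devs = [mx.gpu(i) for i in range(num_devs)]
batch_size = 196 - (196 % num_devs)
data_shape = (batch_size / num_devs, 3, 28, 28)
executors = [loss.simple_bind(d, data=mx.nd.empty(data_shape, d)) for d in devs]

# parameters that must stay in sync across devices, keyed by argument index
param_names = loss.list_arguments()
sync_prefix = ["weight", "bias", "beta", "gamma"]
sync_indices = [i for i, name in enumerate(param_names)
                if any(p in name for p in sync_prefix)]
sync_weights = [[e.list_arguments()[0][i] for e in executors] for i in sync_indices]
sync_grads   = [[e.list_arguments()[1][i] for e in executors] for i in sync_indices]

# seed the shared model in the kvstore: uniform init for weights, 1.0 for BatchNorm gamma
weights = executors[0].list_arguments()[0]
for idx in sync_indices:
    val = mx.nd.zeros(weights[idx].shape)
    if "weight" in param_names[idx]:
        val[:] = np.random.uniform(-0.1, 0.1, val.shape)
    elif "gamma" in param_names[idx]:
        val[:] = 1.0
    mx.kvstore.init(idx, val)

# per-device views of the data/label arguments, and each device's slice of the batch
data_in  = [e.list_arguments()[0][param_names.index('data')] for e in executors]
label_in = [e.list_arguments()[0][param_names.index('loss_label')] for e in executors]
k = batch_size / num_devs
batch_splits = [range(d * k, (d + 1) * k) for d in range(num_devs)]

for data, label in train_dataiter:
    mx.kvstore.pull(sync_indices, out=sync_weights)   # fetch the latest weights
    data = data.asnumpy()
    label = label.asnumpy().flatten()
    for d in range(num_devs):                         # forward/backward on each device's slice
        rows = batch_splits[d]
        data_in[d][:] = data[rows, :]
        label_in[d][:] = label[rows]
        executors[d].forward()
        executors[d].backward()
    for grads in sync_grads:                          # normalize by the global batch size
        for g in grads:
            g /= batch_size
    mx.kvstore.push(sync_indices, sync_grads)         # aggregate gradients and run the updater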