This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

multi-gpu cifar10 #71

Merged (13 commits) on Sep 14, 2015
6 changes: 3 additions & 3 deletions example/cifar10/cifar10.py
@@ -150,7 +150,7 @@ def RandomInit(narray):
pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt)
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
-loss = mx.symbol.Softmax(data=fc, name="sm")
+loss = mx.symbol.Softmax(data=fc, name="loss")


epoch = 9
@@ -169,7 +169,7 @@ def RandomInit(narray):

arg_narrays, grad_narrays = executor.list_arguments()
inputs = dict(zip(loss.list_arguments(), arg_narrays))
-tmp_label = mx.nd.zeros(inputs["sm_label"].shape)
+tmp_label = mx.nd.zeros(inputs["loss_label"].shape)
momentum_narrays = [mx.nd.zeros(item.shape, mx.gpu()) for item in grad_narrays]

block = list(zip(grad_narrays, arg_narrays, momentum_narrays))
@@ -241,7 +241,7 @@ def test_cifar():
label = label.asnumpy().flatten()
tmp_label[:] = label
inputs["data"][:] = data
inputs["sm_label"][:] = tmp_label
inputs["loss_label"][:] = tmp_label
executor.forward()
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
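The rename is not cosmetic: in the mxnet symbol API of this era, Softmax creates an implicit label argument named "<symbol name>_label", so renaming the symbol from "sm" to "loss" is what forces the "sm_label" keys above to become "loss_label". A minimal sketch of that behavior, assuming the 2015-era API this PR targets:

import mxnet as mx

data = mx.symbol.Variable("data")
fc = mx.symbol.FullyConnected(data=data, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="loss")
# the implicit label argument follows the symbol name, so "sm" -> "loss"
# turns "sm_label" into "loss_label"
print(loss.list_arguments())  # expect something like [..., 'loss_label']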
261 changes: 261 additions & 0 deletions example/cifar10/cifar10_multi_gpus.py
@@ -0,0 +1,261 @@
# pylint: skip-file
import numpy as np
import mxnet as mx
import copy
import sys
sys.path.append("../../tests/python/common")
import get_data
import time

# use multiple devices
num_devs = 4
devs = [mx.gpu(i) for i in range(num_devs)]
mx.kvstore.start()

# define the network
conv_cnt = 1
concat_cnt = 1
pool_cnt = 1

def ConvFactory(**kwargs):
    global conv_cnt
    param = copy.copy(kwargs)
    act = param["act_type"]
    del param["act_type"]
    param["workspace"] = 256
    param["name"] = "conv%d" % conv_cnt
    conv = mx.symbol.Convolution(**param)
    bn = mx.symbol.BatchNorm(data=conv, name="bn%d" % conv_cnt)
    relu = mx.symbol.Activation(data=bn, name="%s%d" % (act, conv_cnt), act_type=act)
    conv_cnt += 1
    return relu

def DownsampleFactory(data, ch_3x3, stride=2):
    global pool_cnt
    global concat_cnt
    param = {}
    # conv 3x3
    param["kernel"] = (3, 3)
    param["stride"] = (stride, stride)
    param["num_filter"] = ch_3x3
    param["act_type"] = "relu"
    param["data"] = data
    param["pad"] = (1, 1)
    conv3x3 = ConvFactory(**param)
    # pool
    del param["num_filter"]
    del param["act_type"]
    del param["pad"]
    param["pool_type"] = "max"
    param["name"] = "pool%d" % pool_cnt
    pool = mx.symbol.Pooling(**param)
    pool_cnt += 1
    # concat
    concat = mx.symbol.Concat(*[conv3x3, pool], name="concat%d" % concat_cnt)
    concat_cnt += 1
    return concat

def SimpleFactory(data, ch_1x1, ch_3x3):
    global concat_cnt
    param = {}
    # 1x1
    param["kernel"] = (1, 1)
    param["num_filter"] = ch_1x1
    param["pad"] = (0, 0)
    param["stride"] = (1, 1)
    param["act_type"] = "relu"
    param["data"] = data
    conv1x1 = ConvFactory(**param)

    # 3x3
    param["kernel"] = (3, 3)
    param["num_filter"] = ch_3x3
    param["pad"] = (1, 1)
    conv3x3 = ConvFactory(**param)

    # concat
    concat = mx.symbol.Concat(*[conv1x1, conv3x3], name="concat%d" % concat_cnt)
    concat_cnt += 1
    return concat

data = mx.symbol.Variable(name="data")
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
in3a = SimpleFactory(conv1, 32, 32)
in3b = SimpleFactory(in3a, 32, 48)
in3c = DownsampleFactory(in3b, 80)
in4a = SimpleFactory(in3c, 112, 48)
in4b = SimpleFactory(in4a, 96, 64)
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
in4e = DownsampleFactory(in4d, 96)
in5a = SimpleFactory(in4e, 176, 160)
in5b = SimpleFactory(in5a, 176, 160)
pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt)
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="loss")

# define model updater

def momentum(learning_rate=.01, weight_decay=0.0001, momentum=0.9):
    """Stochastic Gradient Descent (SGD) updates with momentum."""
    momentums = {}
    def momentum_update(key, grad, weight):
        # weight += - learning_rate * (grad + weight_decay * weight)
        if key not in momentums:
            momentums[key] = mx.nd.zeros(grad.shape)
        mom = momentums[key]
        mom *= momentum
        mom += - learning_rate * (grad + weight_decay * weight)
        weight += mom
    return momentum_update

updater = momentum(
    learning_rate=.05, weight_decay=.0001, momentum=0.9)
mx.kvstore.set_updater(updater)
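# the updater runs on the kvstore side: every push of a gradient for a key
# triggers momentum_update(key, grad, weight) against the stored weight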

# infer shape
batch_size = 196
batch_size -= (batch_size % num_devs)
data_shape = (batch_size // num_devs, 3, 28, 28)

# create executors for devices
executors = [loss.simple_bind(d, data = mx.nd.empty(data_shape, d)) for d in devs]

# find the params needed to be synchronized between devices
param_names = loss.list_arguments()
sync_prefix = ["weight", "bias", "beta", "gamma"]
sync_indices = [index for index, name in enumerate(param_names)
                if any(prefix in name for prefix in sync_prefix)]

sync_weights = [[e.list_arguments()[0][i] for e in executors] for i in sync_indices]
sync_grads = [[e.list_arguments()[1][i] for e in executors] for i in sync_indices]
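# note: in this 2015-era API an executor's list_arguments() returns the pair
# (arg_narrays, grad_narrays), as cifar10.py above unpacks it, so [0][i] is
# the i-th argument array and [1][i] the matching gradient array per device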


# init model
weights = executors[0].list_arguments()[0]
for idx in sync_indices:
    shape = weights[idx].shape
    val = mx.nd.zeros(shape)
    if "weight" in param_names[idx]:
        val[:] = np.random.uniform(-0.1, 0.1, shape)
    elif "gamma" in param_names[idx]:
        val[:] = 1.0
    mx.kvstore.init(idx, val)

# data reader
get_data.GetCifar10()

train_dataiter = mx.io.ImageRecordIter(
    path_imgrec="data/cifar/train.rec",
    mean_img="data/cifar/cifar_mean.bin",
    rand_crop=True,
    rand_mirror=True,
    shuffle=True,
    input_shape=(3,28,28),
    batch_size=batch_size,
    nthread=1)

val_dataiter = mx.io.ImageRecordIter(
    path_imgrec="data/cifar/test.rec",
    mean_img="data/cifar/cifar_mean.bin",
    rand_crop=False,
    rand_mirror=False,
    input_shape=(3,28,28),
    batch_size=batch_size,
    nthread=1)


def progress(count, total, epoch, tic):
    bar_len = 50
    filled_len = int(round(bar_len * count / float(total)))
    percents = round(100.0 * count / float(total), 1)
    bar = '=' * filled_len + '-' * (bar_len - filled_len)
    toc = time.time()
    speed = batch_size / float(toc - tic)
    suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
    sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))

def cal_acc(out, label):
    pred = np.argmax(out, axis=1)
    return np.sum(pred == label) * 1.0 / out.shape[0]

def train():
    epoch = 7
    acc_train = 0.
    acc_val = 0.
    k = batch_size // num_devs
    batch_splits = [range(d*k, (d+1)*k) for d in range(num_devs)]
    print("Start training...")
    data_in = [e.list_arguments()[0][param_names.index('data')] for e in executors]
    label_in = [e.list_arguments()[0][param_names.index('loss_label')] for e in executors]

    for i in range(epoch):
        # train
        start = time.time()
        train_acc = 0.0
        val_acc = 0.0
        train_count = 0
        val_count = 0
        all_train_batch = round(50000 / float(batch_size/num_devs) + 1)

        for data, label in train_dataiter:
            tic = time.time()
            # pull weight
            mx.kvstore.pull(sync_indices, out=sync_weights)

            # forward and backward
            data = data.asnumpy()
            label = label.asnumpy().flatten()
            for d in range(num_devs):
                rows = batch_splits[d]
                data_in[d][:] = data[rows, :]
                label_in[d][:] = label[rows]
                executors[d].forward()
                executors[d].backward()

            # normalize gradient
            for grads in sync_grads:
                for g in grads:
                    g /= batch_size

            # push gradient
            mx.kvstore.push(sync_indices, sync_grads)

            # evaluate
            for d in range(num_devs):
                train_acc += cal_acc(executors[d].outputs[0].asnumpy(),
                                     label[batch_splits[d]])
                train_count += 1

            progress(train_count, all_train_batch, i, tic)

        # evaluate
        for data, label in val_dataiter:
            # forward
            data = data.asnumpy()
            label = label.asnumpy().flatten()
            for d in range(num_devs):
                rows = batch_splits[d]
                data_in[d][:] = data[rows, :]
                executors[d].forward()

            # eval
            for d in range(num_devs):
                val_acc += cal_acc(executors[d].outputs[0].asnumpy(),
                                   label[batch_splits[d]])
                val_count += 1

        sys.stdout.write('\n')

        print("Train Acc: %g, Valid Acc: %g, Time: %g sec" % (
            train_acc / train_count,
            val_acc / val_count,
            time.time() - start))

        train_dataiter.reset()
        val_dataiter.reset()

if __name__ == "__main__":
    train()
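The three kvstore calls above carry the whole data-parallel scheme: each batch pulls the shared weights to every device, runs forward/backward locally, then pushes normalized gradients through the registered momentum updater. Below is a hedged, NumPy-only toy of that pattern; store, init, push, pull, and the gradient averaging in push are illustrative stand-ins for the mxnet kvstore, not its actual implementation:

import numpy as np

store = {}      # key -> the authoritative weight array
momentums = {}  # key -> momentum buffer, one per key

def momentum_update(key, grad, weight, lr=0.05, wd=0.0001, mom=0.9):
    # same update rule as the momentum() closure above
    if key not in momentums:
        momentums[key] = np.zeros_like(grad)
    m = momentums[key]
    m *= mom
    m += -lr * (grad + wd * weight)
    weight += m

def init(key, val):
    store[key] = val.copy()

def push(key, grads):
    # combine per-device gradients, then apply the updater in place
    momentum_update(key, sum(grads) / len(grads), store[key])

def pull(key, outs):
    # copy the shared weight into each device's buffer
    for out in outs:
        out[:] = store[key]

# usage: two "devices" sharing one weight vector
init(0, np.ones(3))
dev_w = [np.empty(3), np.empty(3)]
push(0, [np.full(3, 0.2), np.full(3, 0.4)])  # per-device gradients
pull(0, dev_w)
print(dev_w[0])  # both devices now hold the updated weight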
19 changes: 9 additions & 10 deletions example/mnist/mlp_multi_gpu.py
@@ -5,14 +5,14 @@
import sys
sys.path.append("../../tests/python")
import get_data
+import time

# use multiple devices
num_devs = 4
devs = [mx.Context('gpu', i) for i in range(num_devs)]
mx.kvstore.start()

# symbol net
batch_size = 100
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
@@ -54,7 +54,6 @@ def updater(key, grad, weight):

# create executors for devices
executors = [mlp.bind(devs[d], params[d], grads[d]) for d in range(num_devs)]
-forward_out = [mx.nd.zeros(e.heads()[0].shape) for e in executors]

# data reader
get_data.GetMNIST_ubyte()
@@ -77,6 +76,7 @@ def run_sgd():

num_epochs = 9
for epoch in range(num_epochs):
+start = time.time()
print "Epoch %d" % epoch
train_count = 0.0
train_acc = 0.0
@@ -97,16 +97,14 @@ def run_sgd():
params[d][param_names.index('mlp_label')][:] = label[rows]

executors[d].forward()
-executors[d].heads()[0].copyto(forward_out[d])
-executors[d].backward([forward_out[d]])
+executors[d].backward()
# push gradient
for idx in sync_indices:
mx.kvstore.push(idx, [g[idx] for g in grads])

# eval
for d in range(num_devs):
-train_acc += cal_acc(forward_out[d].asnumpy(),
+train_acc += cal_acc(executors[d].outputs[0].asnumpy(),
label[batch_splits[d]])
train_count += 1

@@ -123,15 +121,16 @@ def run_sgd():

# eval
for d in range(num_devs):
-val_acc += cal_acc(executors[d].heads()[0].asnumpy(),
+val_acc += cal_acc(executors[d].outputs[0].asnumpy(),
label[batch_splits[d]])
val_count += 1

print("Train Acc: ", train_acc / train_count)
print("Valid Acc: ", val_acc / val_count)
print("Train Acc: %g, Valid Acc: %g, time: %g" % (
train_acc / train_count,
val_acc / val_count,
time.time() - start))
train_dataiter.reset()
val_dataiter.reset()

if __name__ == "__main__":
run_sgd()
mx.kvstore.stop()
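Both examples shard a batch the same way: contiguous row ranges of k = batch_size // num_devs rows, one range per device, used to slice both data and labels. A small self-contained sketch of that slicing (NumPy; the array contents are made up for illustration):

import numpy as np

num_devs, batch_size = 4, 100
k = batch_size // num_devs
batch_splits = [range(d * k, (d + 1) * k) for d in range(num_devs)]

data = np.arange(batch_size * 3).reshape(batch_size, 3)
shards = [data[rows, :] for rows in batch_splits]  # one slice per device
print([s.shape for s in shards])  # four (25, 3) shards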
4 changes: 3 additions & 1 deletion include/mxnet/kvstore.h
@@ -40,7 +40,9 @@ class KVStore {
*
* clear all key-value pairs stored, updater, and devices bound
*/
-virtual void Stop() { get_impl()->Stop(); delete impl_; impl_ = NULL; }
+virtual void Stop() {
+  if (impl_) { impl_->Stop(); delete impl_; impl_ = NULL; }
+}

/**
* \brief Initialize a list of key-value pair to the store.
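The added null check makes Stop() idempotent: the first call shuts down and frees the implementation, and any later call is a no-op instead of a double delete. A hedged Python analogue of that guard (ToyKVStore is illustrative, not the mxnet binding):

class ToyKVStore:
    def __init__(self):
        self._impl = object()  # stands in for the C++ impl_ pointer

    def stop(self):
        # mirrors: if (impl_) { impl_->Stop(); delete impl_; impl_ = NULL; }
        if self._impl is not None:
            self._impl = None

kv = ToyKVStore()
kv.stop()
kv.stop()  # second call is a harmless no-op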