Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #271 from mli/master
Browse files Browse the repository at this point in the history
[test] testcases for multi-devices
  • Loading branch information
mli committed Oct 12, 2015
2 parents aaf231f + db0bcd4 commit 0125990
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 73 deletions.
73 changes: 0 additions & 73 deletions tests/python/distributed/test_mlp.py

This file was deleted.

7 changes: 7 additions & 0 deletions tests/python/multi-node/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Tests for multiple devices and multiple machines

`CUDNN` must be disabled to run these tests.

`local_*` scripts test multiple devices on a single machine.

`dist_*` scripts test multiple machines.
78 changes: 78 additions & 0 deletions tests/python/multi-node/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# pylint: skip-file
"""Common helpers for the multi-node tests.
- randomness is disabled in all iterators
"""
import sys
sys.path.insert(0, "../common/")
sys.path.insert(0, "../../python/")
import mxnet as mx
import get_data
import numpy as np
import logging

def mnist(batch_size, input_shape, num_parts=1, part_index=0):
    """Build the MNIST training and validation iterators.

    The ubyte files are obtained through ``get_data.GetMNIST_ubyte()`` before
    the iterators are created. Shuffling is switched off for both iterators.

    Parameters
    ----------
    batch_size
        Forwarded as ``batch_size`` to both iterators.
    input_shape
        Forwarded as ``data_shape``; a shape of length 1 enables ``flat``.
    num_parts, part_index
        Forwarded to the training iterator only.

    Returns
    -------
    tuple
        ``(train, val)``.
    """
    get_data.GetMNIST_ubyte()

    # Keyword arguments common to the training and validation iterators.
    shared = dict(
        data_shape=input_shape,
        batch_size=batch_size,
        shuffle=False,
        flat=(len(input_shape) == 1),
        silent=False,
    )

    train_iter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        num_parts=num_parts,
        part_index=part_index,
        **shared)
    val_iter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        **shared)
    return (train_iter, val_iter)

def cifar10(batch_size, input_shape, num_parts=1, part_index=0):
    """Build the CIFAR-10 training and validation iterators.

    The record files are obtained through ``get_data.GetCifar10()`` before
    the iterators are created. Random cropping, random mirroring, shuffling
    and batch rounding are switched off for both iterators.

    Parameters
    ----------
    batch_size
        Forwarded as ``batch_size`` to both iterators.
    input_shape
        Forwarded as ``data_shape`` to both iterators, so training and
        validation batches always share the same shape.
    num_parts, part_index
        Forwarded to the training iterator only.

    Returns
    -------
    tuple
        ``(train, val)``.
    """
    get_data.GetCifar10()

    # Keyword arguments common to both iterators. The validation iterator
    # previously hard-coded data_shape=(3,28,28); it now follows input_shape
    # like the training iterator does.
    shared = dict(
        mean_img="data/cifar/cifar_mean.bin",
        data_shape=input_shape,
        batch_size=batch_size,
        rand_crop=False,
        rand_mirror=False,
        shuffle=False,
        round_batch=False,
    )

    train = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/train.rec",
        num_parts=num_parts,
        part_index=part_index,
        **shared)
    val = mx.io.ImageRecordIter(
        path_imgrec="data/cifar/test.rec",
        **shared)
    return (train, val)

def accuracy(model, data):
    """Compute, log and return the classification accuracy of ``model``.

    ``data`` is reset and passed to ``model.predict``; the predicted class is
    the argmax along axis 1 of the result. ``data`` is then reset again and
    iterated to collect the labels of every batch.

    Parameters
    ----------
    model
        Object exposing ``predict(data)``.
    data
        Resettable iterator yielding ``(_, label)`` pairs where ``label``
        exposes ``asnumpy()``.

    Returns
    -------
    float
        Fraction of predictions that equal the corresponding label.
    """
    # First pass: predictions.
    data.reset()
    predicted = np.argmax(model.predict(data), axis=1)

    # Second pass: labels.
    data.reset()
    label_batches = []
    for _, label in data:
        label_batches.append(label.asnumpy())
    truth = np.concatenate(label_batches).astype('int')
    # Keep only as many labels as there are predictions
    # (presumably drops padding in the final batch -- TODO confirm).
    truth = truth[:len(predicted)]

    hits = np.sum(predicted == truth)
    acc = float(hits) / len(truth)
    logging.info('Accuracy = %f', acc)

    return acc
36 changes: 36 additions & 0 deletions tests/python/multi-node/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# pylint: skip-file
"""Data iterators for the multi-node tests.
Randomness is disabled in all iterators.
The kv must be created before these iterators are built.
"""
import sys
sys.path.insert(0, "../common/")
sys.path.insert(0, "../../python/")
import mxnet as mx
import get_data

def mnist(batch_size, input_shape, num_parts=1, part_index=0):
    """Return ``(train, val)`` MNIST iterators with shuffling disabled.

    ``get_data.GetMNIST_ubyte()`` is called first. ``input_shape`` is used as
    ``data_shape`` for both iterators, and a one-element shape enables
    ``flat``. ``num_parts`` and ``part_index`` are applied to the training
    iterator only.
    """
    get_data.GetMNIST_ubyte()
    is_flat = len(input_shape) == 1

    def build(images, labels, **extra):
        # Both iterators share everything except the files and the
        # partitioning arguments supplied through ``extra``.
        return mx.io.MNISTIter(
            image=images,
            label=labels,
            data_shape=input_shape,
            batch_size=batch_size,
            shuffle=False,
            flat=is_flat,
            silent=False,
            **extra)

    training = build(
        "data/train-images-idx3-ubyte",
        "data/train-labels-idx1-ubyte",
        num_parts=num_parts,
        part_index=part_index)
    validation = build(
        "data/t10k-images-idx3-ubyte",
        "data/t10k-labels-idx1-ubyte")
    return (training, validation)
85 changes: 85 additions & 0 deletions tests/python/multi-node/local_inception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python
# pylint: skip-file
import mxnet as mx
from common import cifar10, accuracy
import logging

# symbol

# Basic Conv + BN + ReLU factory
def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type="relu"):
    """Stack Convolution -> BatchNorm -> Activation on top of ``data``.

    ``num_filter``, ``kernel``, ``stride`` and ``pad`` configure the
    convolution; ``act_type`` selects the activation. Returns the activation
    symbol.
    """
    convolved = mx.symbol.Convolution(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad)
    normalized = mx.symbol.BatchNorm(data=convolved)
    return mx.symbol.Activation(data=normalized, act_type=act_type)

# A Simple Downsampling Factory
def DownsampleFactory(data, ch_3x3):
    """Concatenate a stride-2 3x3 conv branch with a stride-2 3x3 max-pool branch.

    ``ch_3x3`` is the number of filters of the convolution branch. Both
    branches read ``data`` directly. Returns the concatenated symbol.
    """
    branches = [
        # conv 3x3, stride 2 (Conv + BN + ReLU via ConvFactory)
        ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1)),
        # max pool 3x3, stride 2
        mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type='max'),
    ]
    return mx.symbol.Concat(*branches)

# A Simple module
def SimpleFactory(data, ch_1x1, ch_3x3):
    """Concatenate a 1x1 conv branch and a padded 3x3 conv branch over ``data``.

    ``ch_1x1`` and ``ch_3x3`` are the filter counts of the respective
    branches. Returns the concatenated symbol.
    """
    branch_1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)
    branch_3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)
    return mx.symbol.Concat(branch_1x1, branch_3x3)

# Network definition assembled from ConvFactory / SimpleFactory /
# DownsampleFactory. `softmax` is the symbol trained by test_inception.
data = mx.symbol.Variable(name="data")
# Stem: 3x3 Conv + BN + ReLU with 96 filters.
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
# Stage 3: two SimpleFactory modules followed by a stride-2 DownsampleFactory.
in3a = SimpleFactory(conv1, 32, 32)
in3b = SimpleFactory(in3a, 32, 48)
in3c = DownsampleFactory(in3b, 80)
# Stage 4: four SimpleFactory modules followed by a stride-2 DownsampleFactory.
in4a = SimpleFactory(in3c, 112, 48)
in4b = SimpleFactory(in4a, 96, 64)
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
in4e = DownsampleFactory(in4d, 96)
# Stage 5: two SimpleFactory modules.
in5a = SimpleFactory(in4e, 176, 160)
in5b = SimpleFactory(in5a, 176, 160)
# Head: 7x7 average pooling, flatten, 10-output fully connected layer, softmax.
# NOTE(review): the 7x7 kernel presumably spans the whole feature map for the
# 28x28 inputs used by test_inception -- confirm before changing input_shape.
pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="global_pool")
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
softmax = mx.symbol.Softmax(data=fc, name="loss")

def test_inception(devs, kv_type):
    """Train the inception-style network on CIFAR-10 and return its accuracy.

    Parameters
    ----------
    devs
        Passed to ``FeedForward.create`` as ``ctx``.
    kv_type
        Passed to ``FeedForward.create`` as ``kvstore``.

    Returns
    -------
    float
        Result of ``accuracy(model, val)`` on the validation iterator.
    """
    # Seed first so that every run starts from the same weight init.
    mx.random.seed(0)
    logging.basicConfig(level=logging.DEBUG)

    train_iter, val_iter = cifar10(batch_size=128, input_shape=(3, 28, 28))

    # One training round; the validation iterator is also used as eval_data.
    fit_args = dict(
        ctx=devs,
        symbol=softmax,
        X=train_iter,
        kvstore=kv_type,
        eval_data=val_iter,
        num_round=1,
        learning_rate=0.1,
        momentum=0.9,
        wd=0.00001,
        initializer=mx.init.Uniform(0.07),
    )
    trained = mx.model.FeedForward.create(**fit_args)

    return accuracy(trained, val_iter)

if __name__ == "__main__":
    # Only the two-GPU 'local_update_cpu' run is active here. The single-GPU
    # 'none' baseline, the 'local_allreduce_cpu' and 'local_allreduce_device'
    # runs, and the assertions comparing them (baseline > 0.95, differences
    # below 1e-3) are currently disabled.
    gpus = list(map(mx.gpu, range(2)))
    acc1 = test_inception(gpus, 'local_update_cpu')
63 changes: 63 additions & 0 deletions tests/python/multi-node/local_lenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#!/usr/bin/env python
# pylint: skip-file
import mxnet as mx
from common import mnist, accuracy, cifar10
import logging

## define lenet
# Two conv/tanh/max-pool stages followed by two fully connected layers and a
# softmax output. `lenet` is the symbol trained by test_lenet.
# input
data = mx.symbol.Variable('data')
# first conv: 5x5 convolution with 20 filters, tanh, 2x2 max pooling (stride 2)
conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")
pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",
kernel=(2,2), stride=(2,2))
# second conv: 5x5 convolution with 50 filters, tanh, 2x2 max pooling (stride 2)
conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")
pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",
kernel=(2,2), stride=(2,2))
# first fullc: flatten, then 500 hidden units with tanh
flatten = mx.symbol.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")
# second fullc: 10 outputs
fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
# loss: softmax over the 10 outputs
lenet = mx.symbol.Softmax(data=fc2)

def test_lenet(devs, kv_type):
    """Train LeNet on MNIST and return its validation accuracy.

    Parameters
    ----------
    devs
        Passed to ``FeedForward.create`` as ``ctx``.
    kv_type
        Passed to ``FeedForward.create`` as ``kvstore``.

    Returns
    -------
    float
        Result of ``accuracy(model, val)`` on the validation iterator.
    """
    # Seed first so that every run starts from the same weight init.
    mx.random.seed(0)
    logging.basicConfig(level=logging.DEBUG)

    train_iter, val_iter = mnist(batch_size=100, input_shape=(1, 28, 28))

    # Three training rounds; no eval_data is passed during training.
    fit_args = dict(
        ctx=devs,
        kvstore=kv_type,
        symbol=lenet,
        X=train_iter,
        num_round=3,
        learning_rate=0.1,
        momentum=0.9,
        wd=0.00001,
    )
    trained = mx.model.FeedForward.create(**fit_args)

    return accuracy(trained, val_iter)

if __name__ == "__main__":
    gpus = [mx.gpu(idx) for idx in (0, 1)]

    # (devices, kvstore type) for each run, executed in this order. The first
    # entry is the single-GPU baseline; the second repeats it unchanged.
    configs = (
        (mx.gpu(), 'none'),
        (mx.gpu(), 'none'),
        (gpus, 'local_update_cpu'),
        (gpus, 'local_allreduce_cpu'),
        (gpus, 'local_allreduce_device'),
    )
    results = [test_lenet(devs, kv_type) for devs, kv_type in configs]
    base, others = results[0], results[1:]

    # The baseline must reach a sensible accuracy, and every other run must
    # match it to within 1e-3.
    assert base > 0.95
    for acc in others:
        assert abs(base - acc) < 1e-3
Loading

0 comments on commit 0125990

Please sign in to comment.