Merge pull request #50 from antinucleon/master
CIFAR-10 Simple Inception
antinucleon committed Sep 8, 2015
commit ba28bd8 (2 parents: c68a818 + acd7ed1)
Showing 7 changed files with 161 additions and 77 deletions.
121 changes: 75 additions & 46 deletions example/cifar10/cifar10.py
@@ -2,6 +2,11 @@
import numpy as np
import mxnet as mx
import copy
import sys
sys.path.append("../../tests/python")
import get_data


"""
CXXNET Result:
step1: wmat_lr = 0.05, bias_lr = 0.1, mom = 0.9
@@ -49,6 +54,10 @@
[39] train-error:0.00125879 val-error:0.0833
[40] train-error:0.000699329 val-error:0.0842
"""
def CalAcc(out, label):
pred = np.argmax(out, axis=1)
return np.sum(pred == label) * 1.0 / out.shape[0]


np.random.seed(1812)

@@ -62,6 +71,7 @@ def ConvFactory(**kwargs):
act = param["act_type"]
del param["act_type"]
param["name"] = "conv%d" % conv_cnt
param["nstep"] = 64
conv = mx.symbol.Convolution(**param)
bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt)
relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act)
@@ -96,6 +106,7 @@ def DownsampleFactory(data, ch_3x3, stride = 2):
concat_cnt += 1
return concat


def SimpleFactory(data, ch_1x1, ch_3x3):
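    # "Simple Inception" unit: a 1x1 conv branch and a 3x3 conv branch, concatenated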
global concat_cnt
param = {}
@@ -106,7 +117,7 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
param["stride"] = (1, 1)
param["act_type"] = "relu"
param["data"] = data
param["nstep"] = 100
param["nstep"] = 128
conv1x1 = ConvFactory(**param)

# 3x3
@@ -121,12 +132,11 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
return concat

def RandomInit(narray):
in_num = narray.numpy.shape[1]
out_num = narray.numpy.shape[0]
in_num = narray.shape[1]
out_num = narray.shape[0]
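    # Xavier-style scaled uniform init in [-a, a] with a = sqrt(3 / (fan_in + fan_out))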
a = np.sqrt(3.0 / (in_num + out_num))
tmp = mx.narray.create((narray.numpy.shape))
tmp.numpy[:] = np.random.uniform(-a, a, narray.numpy.shape)
tmp.copyto(narray)
tmp = mx.narray.array(np.random.uniform(-a, a, narray.shape))
narray[:] = tmp

data = mx.symbol.Variable(name="data")
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
@@ -143,110 +153,129 @@ def RandomInit(narray):
pool = mx.symbol.Pooling(data=in5b, pool_type="avg", kernel=(7,7), name="pool%d" % pool_cnt)
flatten = mx.symbol.Flatten(data=pool, name="flatten1")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="softmax")
loss = mx.symbol.Softmax(data=fc, name="sm")
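# head: 7x7 average pooling -> flatten -> 10-way fully connected -> softmax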

args_list = loss.list_arguments()

data_shape = (128, 3, 28, 28)

batch_size = 128
data_shape = (batch_size, 3, 28, 28)
arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape)

arg_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes]

inputs = dict(zip(args_list, arg_narrays))

name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.create(out_shapes[0])
pred = mx.narray.zeros(out_shapes[0])

np.random.seed(0)
# set random weight

for name, narray in inputs.items():
if "weight" in name:
tmp = mx.narray.create(name2shape[name])
tmp.numpy[:] = np.random.uniform(-0.07, 0.07, name2shape[name])
tmp.copyto(narray)
narray[:] = np.random.uniform(-0.1, 0.1, narray.shape)
if "bias" in name:
narray[:] = 0.0
if "gamma" in name:
narray[:] = 1.0
if "beta" in name:
narray[:] = 0.0

# bind executer
# TODO(bing): think of a better bind interface
executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays)
executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays)
# update

out_narray = executor.heads()[0]
grad_narray = mx.narray.create(out_narray.shape)

epoch = 9
lr = 0.1
wd = 0.0004
lr = 0.05
wd = 0.0001
momentum = 0.9

def Update(grad, weight):
weight[:] -= lr * grad / batch_size
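# SGD with momentum and weight decay:
#   mom    <- momentum * mom - lr * (grad / batch_size + wd * weight)
#   weight <- weight + mom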
def Update(grad, weight, mom):
mom[:] *= momentum
mom[:] += -lr * (grad / batch_size + wd * weight)
weight[:] += mom

block = list(zip(grad_narrays, arg_narrays))
block = list(zip(grad_narrays, arg_narrays, mom_narrays))

#check data
get_data.GetCifar10()

train_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/train.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
input_shape=(3,28,28),
batch_size=128,
batch_size=batch_size,
nthread=1)
test_dataiter = mx.io.ImageRecordIter(
path_imgrec="data/cifar/test.rec",
mean_img="data/cifar/cifar_mean.bin",
rand_crop=True,
rand_mirror=True,
rand_crop=False,
rand_mirror=False,
input_shape=(3,28,28),
batch_size=100,
batch_size=batch_size,
nthread=1)

tmp_label = mx.narray.create(name2shape["sm_label"])
tmp_label = mx.narray.zeros(name2shape["sm_label"])

def progress(count, total, suffix=''):
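    # simple in-place text progress bar: rewrites the current line via '\r', no newline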
bar_len = 80
filled_len = int(round(bar_len * count / float(total)))

percents = round(100.0 * count / float(total), 1)
bar = '=' * filled_len + '-' * (bar_len - filled_len)

sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))

def test_cifar():
acc_train = 0.
acc_val = 0.
print("Start training...")
for i in range(epoch):
# train
print("Epoch %d" % i)
train_acc = 0.0
val_acc = 0.0
train_nbatch = 0
val_nbatch = 0
all_train_bacth = 50000 / float(batch_size)
for data, label in train_dataiter:
data = data
tmp_label.numpy[:] = label.numpy.reshape(tmp_label.shape)
data.copyto(inputs["data"])
tmp_label.copyto(inputs["sm_label"])
progress(train_nbatch, all_train_bacth, "Epoch %d" % i)
label = label.asnumpy().flatten()
tmp_label[:] = label
inputs["data"][:] = data
inputs["sm_label"][:] = tmp_label
executor.forward()
out_narray.copyto(pred)
train_acc += CalAcc(pred.numpy, label.numpy.flatten())
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
train_nbatch += 1
out_narray.copyto(grad_narray)
executor.backward([grad_narray])
executor.backward([out_narray])

for grad, weight in block:
Update(grad, weight)
for grad, weight, mom in block:
Update(grad, weight, mom)

# evaluate
for data, label in val_dataiter:
data = data
label = label.numpy.flatten()
data.copyto(inputs["data"])
for data, label in test_dataiter:
label = label.asnumpy().flatten()
inputs["data"][:] = data
executor.forward()
out_narray.copyto(pred)
val_acc += CalAcc(pred.numpy, label)
pred[:] = out_narray
val_acc += CalAcc(pred.asnumpy(), label)
val_nbatch += 1
acc_train = train_acc / train_nbatch
acc_val = val_acc / val_nbatch
sys.stdout.write('\n')
print("Train Acc: ", train_acc / train_nbatch)
print("Valid Acc: ", val_acc / val_nbatch)
train_dataiter.reset()
val_dataiter.reset()
assert(acc_train > 0.98)
assert(acc_val > 0.97)
test_dataiter.reset()

if __name__ == "__main__":
test_cifar()
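
For reference, the training loop above replaces the plain SGD step used previously with SGD plus momentum and weight decay. A minimal NumPy-only sketch of that update rule, assuming a placeholder weight shape and a dummy gradient (none of these values are part of the commit):

import numpy as np

lr, wd, momentum, batch_size = 0.05, 0.0001, 0.9, 128  # hyper-parameters from the script above

weight = np.random.uniform(-0.1, 0.1, (10, 5))  # placeholder weight matrix
mom = np.zeros_like(weight)                     # momentum buffer starts at zero
grad = np.ones_like(weight)                     # dummy gradient standing in for executor.backward()

# one step, mirroring Update(grad, weight, mom) in cifar10.py above
mom = momentum * mom - lr * (grad / batch_size + wd * weight)
weight = weight + mom
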
55 changes: 30 additions & 25 deletions example/mnist/mlp_gpu.py
@@ -1,11 +1,15 @@
# pylint: skip-file

import mxnet as mx
import numpy as np
import os, gzip
import pickle as pickle
import sys
sys.path.append("../../tests/python")
import get_data



def CalAcc(out, label):
pred = np.argmax(out, axis=1)
return np.sum(pred == label) * 1.0 / out.shape[0]
@@ -20,40 +24,42 @@ def CalAcc(out, label):
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
softmax = mx.symbol.Softmax(data = fc3, name = 'sm')
args_list = softmax.list_arguments()

# infer shape
data_shape = (batch_size, 784)
arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape)

arg_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.create(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]

# create GPU NArray for data
arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
inputs = dict(zip(args_list, arg_narrays))

# create CPU NArray for result stat
name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.create(out_shapes[0])
pred = mx.narray.zeros(out_shapes[0])


np.random.seed(0)
# set random weight
np.random.seed(0)
for name, narray in inputs.items():
if "weight" in name:
tmp = mx.narray.create(name2shape[name])
tmp.numpy[:] = np.random.uniform(-0.07, 0.07, name2shape[name])
tmp = mx.narray.array(np.random.uniform(-0.07, 0.07, name2shape[name]))
tmp.copyto(narray)
if "bias" in name:
narray[:] = 0.0

# bind executer
# TODO(bing): think of a better bind interface
executor = softmax.bind(mx.Context('gpu'), arg_narrays, grad_narrays)
# update

# create gradient NArray
out_narray = executor.heads()[0]
grad_narray = mx.narray.create(out_narray.shape)
grad_narray = mx.narray.zeros(out_narray.shape, ctx=mx.Context("gpu"))


# update
epoch = 9
lr = 0.1
wd = 0.0004

# SGD Update rule
def Update(grad, weight):
weight[:] -= lr * grad / batch_size

@@ -71,7 +77,7 @@ def Update(grad, weight):
label="data/t10k-labels-idx1-ubyte",
batch_size=batch_size, shuffle=True, flat=True, silent=False)

tmp_label = mx.narray.create(name2shape["sm_label"])
tmp_label = mx.narray.zeros(name2shape["sm_label"])

def test_mlp():
acc_train = 0.
@@ -84,28 +90,27 @@ def test_mlp():
train_nbatch = 0
val_nbatch = 0
for data, label in train_dataiter:
data = data
tmp_label.numpy[:] = label.numpy.reshape(tmp_label.shape)
data.copyto(inputs["data"])
tmp_label.copyto(inputs["sm_label"])
label = label.asnumpy().reshape(tmp_label.shape)
tmp_label[:] = label
inputs["data"][:] = data
inputs["sm_label"][:] = tmp_label
executor.forward()
out_narray.copyto(pred)
train_acc += CalAcc(pred.numpy, label.numpy.flatten())
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
train_nbatch += 1
out_narray.copyto(grad_narray)
grad_narray[:] = out_narray
executor.backward([grad_narray])

for grad, weight in block:
Update(grad, weight)

# evaluate
for data, label in val_dataiter:
data = data
label = label.numpy.flatten()
data.copyto(inputs["data"])
label = label.asnumpy().flatten()
inputs["data"][:] = data
executor.forward()
out_narray.copyto(pred)
val_acc += CalAcc(pred.numpy, label)
pred[:] = out_narray
val_acc += CalAcc(pred.asnumpy(), label)
val_nbatch += 1
acc_train = train_acc / train_nbatch
acc_val = val_acc / val_nbatch
(Remainder of the mlp_gpu.py diff and the other 5 changed files are not shown.)