This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Merge pull request #2 from dmlc/master
rebase
mli committed Sep 13, 2015
2 parents 144d538 + 3d53ac7 commit d5cebc8
Showing 81 changed files with 2,553 additions and 1,819 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -57,3 +57,4 @@ __pycache__
build
dmlc-core
mshadow
data
3 changes: 2 additions & 1 deletion .travis.yml
@@ -11,7 +11,7 @@ env:
- TASK=python CXX=g++
- TASK=python3 CXX=g++
- TASK=python_naive CXX=g++
- TASK=unittest_gtest CXX=g++
- TASK=cpp_unittest CXX=g++

# dependent apt packages
addons:
@@ -47,6 +47,7 @@ before_install:

install:
- pip install cpplint pylint --user `whoami`
- make -f dmlc-core/scripts/packages.mk gtest
- if [ "$CXX" = "g++" ]; then export CXX="g++-4.8" CC="gcc-4.8"; fi

script:
20 changes: 6 additions & 14 deletions Makefile
@@ -54,19 +54,12 @@ else
CFLAGS+= -DMXNET_USE_OPENCV=0
endif

# setup opencv
ifeq ($(USE_OPENCV_DECODER),1)
CFLAGS+= -DMXNET_USE_OPENCV_DECODER=1
else
CFLAGS+= -DMXNET_USE_OPENCV_DECODER=0
endif

ifeq ($(USE_OPENMP_ITER), 1)
ifeq ($(USE_OPENMP), 1)
CFLAGS += -fopenmp
endif

ifeq ($(USE_CUDNN), 1)
CFLAGS += -DCXXNET_USE_CUDNN=1
CFLAGS += -DMSHADOW_USE_CUDNN=1
LDFLAGS += -lcudnn
endif

@@ -84,7 +77,6 @@ endif

.PHONY: clean all test lint doc

BIN = tests/test_threaded_engine
all: lib/libmxnet.a lib/libmxnet.so $(BIN)

SRC = $(wildcard src/*.cc src/*/*.cc)
@@ -114,13 +106,13 @@ lib/libmxnet.a: $(ALL_DEP)
lib/libmxnet.so: $(ALL_DEP)
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)

tests/% : tests/%.cc lib/libmxnet.a
$(CXX) -std=c++0x $(CFLAGS) -MM -MT tests/$*.o $< >tests/$*.d
$(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cc %.a, $^) $(LDFLAGS)

$(DMLC_CORE)/libdmlc.a:
+ cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

include tests/cpp/unittest.mk

test: tests/cpp/unittest

lint:
python dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts python

1 change: 1 addition & 0 deletions doc/sphinx_util.py
@@ -14,6 +14,7 @@ def run_build_mxnet(folder):
subprocess.call('cd ..; rm -rf mshadow;' +
'git clone https://github.com/dmlc/mshadow', shell = True)
subprocess.call('cd ..; cp make/readthedocs.mk config.mk', shell = True)
subprocess.call('cd ..; rm -rf build', shell = True)
retcode = subprocess.call("cd %s; make" % folder, shell = True)
if retcode < 0:
sys.stderr.write("build terminated by signal %s" % (-retcode))
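For reference, the full helper after this change reads roughly as below. This is a sketch reconstructed from the visible hunk; the surrounding readthedocs setup code and the caller of run_build_mxnet are assumed, not shown in the diff.

import subprocess
import sys

def run_build_mxnet(folder):
    # Refresh the mshadow dependency and use the readthedocs build configuration.
    subprocess.call('cd ..; rm -rf mshadow;' +
                    'git clone https://github.com/dmlc/mshadow', shell=True)
    subprocess.call('cd ..; cp make/readthedocs.mk config.mk', shell=True)
    # The line added by this commit: drop any stale build directory first.
    subprocess.call('cd ..; rm -rf build', shell=True)
    retcode = subprocess.call("cd %s; make" % folder, shell=True)
    if retcode < 0:
        sys.stderr.write("build terminated by signal %s" % (-retcode))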
66 changes: 29 additions & 37 deletions example/cifar10/cifar10.py
@@ -5,7 +5,7 @@
import sys
sys.path.append("../../tests/python")
import get_data

import time

"""
CXXNET Result:
@@ -70,8 +70,8 @@ def ConvFactory(**kwargs):
param = copy.copy(kwargs)
act = param["act_type"]
del param["act_type"]
param["workspace"] = 256
param["name"] = "conv%d" % conv_cnt
param["nstep"] = 64
conv = mx.symbol.Convolution(**param)
bn = mx.symbol.BatchNorm(data = conv, name="bn%d" % conv_cnt)
relu = mx.symbol.Activation(data = bn, name = "%s%d" % (act, conv_cnt), act_type=act)
@@ -89,13 +89,11 @@ def DownsampleFactory(data, ch_3x3, stride = 2):
param["num_filter"] = ch_3x3
param["act_type"] = "relu"
param["data"] = data
param["nstep"] = 100
param["pad"] = (1, 1)
conv3x3 = ConvFactory(**param)
# pool
del param["num_filter"]
del param["act_type"]
del param["nstep"]
del param["pad"]
param["pool_type"] = "max"
param["name"] = "pool%d" % pool_cnt
@@ -117,7 +115,6 @@ def SimpleFactory(data, ch_1x1, ch_3x3):
param["stride"] = (1, 1)
param["act_type"] = "relu"
param["data"] = data
param["nstep"] = 128
conv1x1 = ConvFactory(**param)

# 3x3
@@ -135,15 +132,15 @@ def RandomInit(narray):
in_num = narray.shape[1]
out_num = narray.shape[0]
a = np.sqrt(3.0 / (in_num + out_num))
tmp = mx.narray.array(np.random.uniform(-a, a, narray.shape))
tmp = mx.nd.array(np.random.uniform(-a, a, narray.shape))
narray[:] = tmp

data = mx.symbol.Variable(name="data")
conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type="relu")
in3a = SimpleFactory(conv1, 32, 32)
in3b = SimpleFactory(in3a, 32, 48)
in3c = DownsampleFactory(in3b, 80)
in4a = SimpleFactory(in3c, 112, 38)
in4a = SimpleFactory(in3c, 112, 48)
in4b = SimpleFactory(in4a, 96, 64)
in4c = SimpleFactory(in4b, 80, 80)
in4d = SimpleFactory(in4c, 48, 96)
@@ -155,22 +152,27 @@ def RandomInit(narray):
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="sm")

args_list = loss.list_arguments()

epoch = 9
lr = 0.05
wd = 0.0001
momentum = 0.9

batch_size = 128
data_shape = (batch_size, 3, 28, 28)
arg_shapes, out_shapes, aux_shapes = loss.infer_shape(data=data_shape)

arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
mom_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
aux_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in aux_shapes]
in_data = mx.nd.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), data = in_data)

out_narray = executor.outputs[0]
pred = mx.nd.zeros(out_narray.shape, mx.cpu())

inputs = dict(zip(args_list, arg_narrays))
arg_narrays, grad_narrays = executor.list_arguments()
inputs = dict(zip(loss.list_arguments(), arg_narrays))
tmp_label = mx.nd.zeros(inputs["sm_label"].shape)
momentum_narrays = [mx.nd.zeros(item.shape, mx.gpu()) for item in grad_narrays]

name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.zeros(out_shapes[0])
block = list(zip(grad_narrays, arg_narrays, momentum_narrays))

np.random.seed(0)
# set random weight
@@ -185,25 +187,11 @@ def RandomInit(narray):
if "beta" in name:
narray[:] = 0.0

# bind executer
# TODO(bing): think of a better bind interface
executor = loss.bind(mx.Context('gpu'), arg_narrays, grad_narrays, 'write', aux_narrays)
# update

out_narray = executor.heads()[0]

epoch = 9
lr = 0.05
wd = 0.0001
momentum = 0.9

def Update(grad, weight, mom):
mom[:] *= momentum
mom[:] += -lr * (grad / batch_size + wd * weight)
weight[:] += mom

block = list(zip(grad_narrays, arg_narrays, mom_narrays))

#check data
get_data.GetCifar10()

@@ -224,17 +212,19 @@ def Update(grad, weight, mom):
batch_size=batch_size,
nthread=1)

tmp_label = mx.narray.zeros(name2shape["sm_label"])

def progress(count, total, suffix=''):
bar_len = 80
def progress(count, total, epoch, toc):
bar_len = 50
filled_len = int(round(bar_len * count / float(total)))

percents = round(100.0 * count / float(total), 1)
bar = '=' * filled_len + '-' * (bar_len - filled_len)

tic = time.time()
speed = batch_size / float(tic - toc)
suffix = "Epoch %d, Speed: %.2f pic/sec" % (epoch, speed)
sys.stdout.write('[%s] %s%s ...%s\r' % (bar, percents, '%', suffix))


def test_cifar():
acc_train = 0.
acc_val = 0.
@@ -245,9 +235,9 @@ def test_cifar():
val_acc = 0.0
train_nbatch = 0
val_nbatch = 0
all_train_bacth = 50000 / float(batch_size)
all_train_bacth = round(50000 / float(batch_size) + 1)
for data, label in train_dataiter:
progress(train_nbatch, all_train_bacth, "Epoch %d" % i)
toc = time.time()
label = label.asnumpy().flatten()
tmp_label[:] = label
inputs["data"][:] = data
Expand All @@ -256,10 +246,12 @@ def test_cifar():
pred[:] = out_narray
train_acc += CalAcc(pred.asnumpy(), label)
train_nbatch += 1
executor.backward([out_narray])
#executor.backward([out_narray])
executor.backward()

for grad, weight, mom in block:
Update(grad, weight, mom)
progress(train_nbatch, all_train_bacth, i, toc)

# evaluate
for data, label in test_dataiter:
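The main change in this example is the move from manually allocating argument, gradient, and momentum arrays and calling loss.bind(...) to the simple_bind path shown above. Below is a minimal sketch of that pattern, assuming the 2015-era mx.nd/simple_bind API used in this diff; the tiny network is hypothetical, only there to give the executor something to bind against.

import numpy as np
import mxnet as mx

# Hypothetical toy network (names follow the example's conventions).
data = mx.symbol.Variable(name="data")
flatten = mx.symbol.Flatten(data=data, name="flatten")
fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10, name="fc1")
loss = mx.symbol.Softmax(data=fc, name="sm")

batch_size = 128
data_shape = (batch_size, 3, 28, 28)
lr, wd, momentum = 0.05, 0.0001, 0.9

# simple_bind allocates the argument and gradient arrays for us.
in_data = mx.nd.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), data=in_data)

arg_narrays, grad_narrays = executor.list_arguments()
inputs = dict(zip(loss.list_arguments(), arg_narrays))
momentum_narrays = [mx.nd.zeros(g.shape, mx.gpu()) for g in grad_narrays]

# SGD with momentum, as in the Update() helper of this example.
def Update(grad, weight, mom):
    mom[:] *= momentum
    mom[:] += -lr * (grad / batch_size + wd * weight)
    weight[:] += mom

# One training step; filling inputs["data"] and the label from the data
# iterator is assumed to have happened, as in the example's loop.
executor.forward()
executor.backward()
for grad, weight, mom in zip(grad_narrays, arg_narrays, momentum_narrays):
    Update(grad, weight, mom)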
14 changes: 7 additions & 7 deletions example/mnist/mlp_gpu.py
@@ -30,28 +30,28 @@ def CalAcc(out, label):
arg_shapes, out_shapes, aux_shapes = softmax.infer_shape(data=data_shape)

# create GPU NArray for data
arg_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
grad_narrays = [mx.narray.zeros(shape, ctx=mx.Context("gpu")) for shape in arg_shapes]
arg_narrays = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes]
grad_narrays = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes]
inputs = dict(zip(args_list, arg_narrays))

# create CPU NArray for result stat
name2shape = dict(zip(args_list, arg_shapes))
pred = mx.narray.zeros(out_shapes[0])
pred = mx.nd.zeros(out_shapes[0])


# set random weight
np.random.seed(0)
for name, narray in inputs.items():
if "weight" in name:
tmp = mx.narray.array(np.random.uniform(-0.07, 0.07, name2shape[name]))
tmp = mx.nd.array(np.random.uniform(-0.07, 0.07, name2shape[name]))
tmp.copyto(narray)

# bind executer
# TODO(bing): think of a better bind interface
executor = softmax.bind(mx.Context('gpu'), arg_narrays, grad_narrays)
# create gradient NArray
out_narray = executor.heads()[0]
grad_narray = mx.narray.zeros(out_narray.shape, ctx=mx.Context("gpu"))
out_narray = executor.outputs[0]
grad_narray = mx.nd.zeros(out_narray.shape, ctx=mx.gpu())


# update
Expand All @@ -77,7 +77,7 @@ def Update(grad, weight):
label="data/t10k-labels-idx1-ubyte",
batch_size=batch_size, shuffle=True, flat=True, silent=False)

tmp_label = mx.narray.zeros(name2shape["sm_label"])
tmp_label = mx.nd.zeros(name2shape["sm_label"])

def test_mlp():
acc_train = 0.
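This file only tracks the API rename: mx.narray becomes mx.nd, mx.Context("gpu") becomes mx.gpu(), and executor.heads() becomes executor.outputs. A small sketch of the renamed allocation and weight-initialization calls, with hypothetical shapes and argument names standing in for the real MLP:

import numpy as np
import mxnet as mx

# Hypothetical argument names and shapes, for illustration only.
args_list  = ["fc1_weight", "fc1_bias", "sm_label"]
arg_shapes = [(784, 128), (128,), (100,)]

arg_narrays  = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes]
grad_narrays = [mx.nd.zeros(shape, ctx=mx.gpu()) for shape in arg_shapes]
name2shape = dict(zip(args_list, arg_shapes))

# Random weight init, matching the example: uniform in [-0.07, 0.07].
np.random.seed(0)
for name, narray in zip(args_list, arg_narrays):
    if "weight" in name:
        tmp = mx.nd.array(np.random.uniform(-0.07, 0.07, name2shape[name]))
        tmp.copyto(narray)

tmp_label = mx.nd.zeros(name2shape["sm_label"])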
8 changes: 4 additions & 4 deletions example/mnist/mlp_multi_gpu.py
@@ -43,18 +43,18 @@ def updater(key, grad, weight):
np.random.seed(0)
for idx in sync_indices:
shape = param_shapes[idx]
val = mx.narray.zeros(shape)
val = mx.nd.zeros(shape)
if "weight" in param_names[idx]:
val[:] = np.random.uniform(-0.07, 0.07, shape)
mx.kvstore.init(idx, val)

# allocate device's memory
params = [[mx.narray.zeros(s, d) for s in param_shapes] for d in devs]
grads = [[mx.narray.zeros(s, d) for s in param_shapes] for d in devs]
params = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs]
grads = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs]

# create executors for devices
executors = [mlp.bind(devs[d], params[d], grads[d]) for d in range(num_devs)]
forward_out = [mx.narray.zeros(e.heads()[0].shape) for e in executors]
forward_out = [mx.nd.zeros(e.heads()[0].shape) for e in executors]

# data reader
get_data.GetMNIST_ubyte()
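The multi-GPU example keeps the same structure and only switches to the mx.nd names: initialize shared parameters in the key-value store, then allocate one copy of parameters and gradients per device. A sketch of that pattern follows; the device list construction, parameter names, and shapes are assumptions for illustration, while the mx.kvstore.init and mx.nd.zeros calls are taken directly from the diff.

import numpy as np
import mxnet as mx

num_devs = 2                                 # assumption for illustration
devs = [mx.gpu(i) for i in range(num_devs)]  # assumes mx.gpu() accepts a device index
param_names  = ["fc1_weight", "fc1_bias"]    # hypothetical parameter names
param_shapes = [(784, 128), (128,)]          # hypothetical shapes
sync_indices = range(len(param_shapes))

# Initialize the shared key-value store, as in the hunk above.
np.random.seed(0)
for idx in sync_indices:
    shape = param_shapes[idx]
    val = mx.nd.zeros(shape)
    if "weight" in param_names[idx]:
        val[:] = np.random.uniform(-0.07, 0.07, shape)
    mx.kvstore.init(idx, val)

# One copy of parameters and gradients per device.
params = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs]
grads  = [[mx.nd.zeros(s, d) for s in param_shapes] for d in devs]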
17 changes: 15 additions & 2 deletions include/mxnet/base.h
@@ -30,7 +30,21 @@
*\brief whether to use cudnn library for convolution
*/
#ifndef MXNET_USE_CUDNN
#define MXNET_USE_CUDNN 0
#define MXNET_USE_CUDNN MSHADOW_USE_CUDNN
#endif

/*! \brief Error message for using gpu when MXNET_USE_CUDA==0 */
#define MXNET_GPU_NOT_ENABLED_ERROR "GPU is not enabled"

/*!
* \brief define compatible keywords in g++
* Used to support g++-4.6 and g++4.7
*/
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ == 6
#define override
#define final
#endif
#endif

/*! \brief namespace of mxnet */
@@ -50,7 +64,6 @@ typedef mshadow::TShape TShape;
typedef mshadow::TBlob TBlob;
} // namespace mxnet


//! \cond Doxygen_Suppress
namespace dmlc {
// Add a few patches to support TShape in dmlc/parameter.
(Diff listings for the remaining changed files are not shown here.)
