Skip to content

Commit

Permalink
Merge pull request apache#9 from dato-code/merge_with_upstream
Browse files Browse the repository at this point in the history
Merge with upstream
  • Loading branch information
Jay Gu committed Jan 7, 2016
2 parents a312fd3 + a774209 commit d167481
Show file tree
Hide file tree
Showing 47 changed files with 1,980 additions and 86 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ bin/im2rec: tools/im2rec.cc $(ALL_DEP)

$(BIN) :
@mkdir -p $(@D)
$(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)
$(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS)

include tests/cpp/unittest.mk

Expand Down Expand Up @@ -208,11 +208,11 @@ rpkg: roxygen

clean:
$(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~

clean_all: clean
cd $(DMLC_CORE); make clean; cd -
cd $(PS_PATH); make clean; cd -

clean_all: clean

-include build/*.d
-include build/*/*.d
ifneq ($(EXTRA_OPERATORS),)
Expand Down
6 changes: 5 additions & 1 deletion cmake/Utils.cmake
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# For cmake_parse_arguments
include(CMakeParseArguments)

################################################################################################
# Command alias for debugging messages
# Usage:
Expand Down Expand Up @@ -395,4 +398,5 @@ function(mxnet_source_group group)
file(GLOB_RECURSE srcs2 ${CAFFE_SOURCE_GROUP_GLOB_RECURSE})
source_group(${group} FILES ${srcs2})
endif()
endfunction()
endfunction()

2 changes: 1 addition & 1 deletion dmlc-core
27 changes: 26 additions & 1 deletion doc/build.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Our goal is to build the shared library:
The minimal building requirement is

- A recent c++ compiler supporting C++ 11 such as `g++ >= 4.8` or `clang`
- A BLAS library, such as `libblas`, `libblas`, `openblas` `intel mkl`
- A BLAS library, such as `libblas`, `atlas`, `openblas` or `intel mkl`

Optional libraries

Expand Down Expand Up @@ -239,6 +239,31 @@ Now you should have the R package as a tar.gz file and you can install it as a n
R CMD INSTALL mxnet_0.5.tar.gz
```
To install the package using GPU on Windows without building the package from scratch. Note that you need a couple of programs installed already:
- You'll need the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit). This depends on Visual Studio, and a free compatible version would be [Visual Studio Community 2013](https://www.visualstudio.com/en-us/news/vs2013-community-vs.aspx). For instructions and compatibility checks, read http://docs.nvidia.com/cuda/cuda-getting-started-guide-for-microsoft-windows/ .

- You will also need to register as a developer at nvidia and download CUDNN V3, https://developer.nvidia.com/cudnn .


1. Download the mxnet package as a ZIP from the Github repository https://github.com/dmlc/mxnet and unpack it. You will be editing the `/mxnet/R-package` folder.

2. Download the most recent GPU-enabled package from the [Releases tab](https://github.com/dmlc/mxnet/releases). Unzip this file so you have a folder `/nocudnn`. Note that this file and the folder you'll save it in will be used for future reference and not directly for installing the package. Only some files will be copied from it into the `R-package` folder.
(Note: you now have 2 folders we're working with, possibly in different locations, that we'll reference with `R-package/` and `nocudnn/`.)
3. Download CUDNN V3 from https://developer.nvidia.com/cudnn. Unpack the .zip file and you'll see 3 folders, `/bin`, `/include`, `/lib`. Copy and replace these 3 folders into `nocudnn/3rdparty/cudnn/`, or unpack the .zip file there directly.

4. Create the folder `R-package/inst/libs/x64`. We only support 64-bit operating system now, so you need the x64 folder;

5. Put dll files in `R-package/inst/libs/x64`.

The first dll file you need is `nocudnn/lib/libmxnet.dll`. The other dll files you need are the ones in all 4 subfolders of `nocudnn/3rdparty/`, for the `cudnn` and `openblas` you'll need to look in the `/bin` folders. There should be 11 dll files now in `R-package/inst/libs/x64`.

6. Copy the folder `nocudnn/include/` to `R-package/inst/`. So now you should have a folder `R-package/inst/include/` with 3 subfolders.

7. Run `R CMD INSTALL --no-multiarch R-package`. Make sure that R is added to your PATH in Environment Variables. Running the command `Where R` in Command Prompt should return the location.

Note on Library Build:

We isolate the library build with Rcpp end to maximize the portability
Expand Down
4 changes: 2 additions & 2 deletions example/autoencoder/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, int
def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
def l2_norm(label, pred):
return np.mean(np.square(label-pred))/2.0
solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
solver.set_metric(mx.metric.CustomMetric(l2_norm))
solver.set_monitor(Monitor(1000))
data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True,
Expand All @@ -154,7 +154,7 @@ def l2_norm(label, pred):
def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None):
def l2_norm(label, pred):
return np.mean(np.square(label-pred))/2.0
solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler)
solver.set_metric(mx.metric.CustomMetric(l2_norm))
solver.set_monitor(Monitor(1000))
data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True,
Expand Down
9 changes: 7 additions & 2 deletions example/cpp/Makefile
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
CFLAGS=-I ../../include -Wall -O3 -msse3 -funroll-loops -Wno-unused-parameter -Wno-unknown-pragmas -fopenmp -I ../../mshadow -I ../../dmlc-core/include
LDFLAGS=-L ../../lib -lmxnet -lopenblas -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 -DMSHADOW_USE_CUDA=1

CXX=g++

mlp: ./mlp.cpp
g++ -std=c++0x $(CFLAGS) $(LDFLAGS) -o $@ $^
$(CXX) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS)

use_ndarray: ./use_ndarray.cpp
g++ -std=c++0x $(CFLAGS) $(LDFLAGS) -o $@ $^
$(CXX) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS)

lint:
python2 ../../dmlc-core/scripts/lint.py mxnet "cpp" ./

clean:
rm -f mlp use_ndarray
66 changes: 66 additions & 0 deletions example/fcn-xs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
FCN-xs EXAMPLES
---------------
This folder contains the examples of image segmentation in MXNet.

## Sample results
![fcn-xs pasval_voc result](https://github.com/dmlc/web-data/blob/master/mxnet/image/fcnxs-example-result.jpg)

we have trained a simple fcn-xs model, the parameter is below:

| model | lr (fixed) | epoch |
| ---- | ----: | ---------: |
| fcn-32s | 1e-10 | 31 |
| fcn-16s | 1e-12 | 27 |
| fcn-8s | 1e-14 | 19 |

the training image number is only : 2027, and the Validation image number is: 462

## How to train fcn-xs in mxnet
#### step1: download the vgg16fc model and experiment data
* vgg16fc model : you can download the ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC).
this is the fully convolution style of the origin
[VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel), and the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt), the vgg16 model has [license](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only.
* experiment data : you can download the ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC), and Extract it. the file/folder will be like:
```JPEGImages folder```, ```SegmentationClass folder```, ```train.lst```, ```val.lst```, ```test.lst```

#### step2: train fcn-xs model
* if you want to train the fcn-8s model, it's better for you trained the fcn-32s and fcn-16s model firstly.
when training the fcn-32s model, run in shell ```./run_fcnxs.sh```, the script in it is:
```shell
python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers --epoch=74 --init-type=vgg16
```
* in the fcn_xs.py, you may need to change the directory ```root_dir```, ```flist_name```, ``fcnxs_model_prefix``` for your own data.
* when you train fcn-16s or fcn-8s model, you should change the code in ```run_fcnxs.sh``` corresponding, such as when train fcn-16s, comment out the fcn32s script, then it will like this:
```shell
python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 --epoch=31 --init-type=fcnxs
```
* the output log may like this(when training fcn-8s):
```c++
INFO:root:Start training with gpu(3)
INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318
INFO:root:Epoch[0] Batch [100] Speed: 1.11 samples/sec Train-accuracy=0.904681
INFO:root:Epoch[0] Batch [150] Speed: 1.13 samples/sec Train-accuracy=0.908053
INFO:root:Epoch[0] Batch [200] Speed: 1.12 samples/sec Train-accuracy=0.912219
INFO:root:Epoch[0] Batch [250] Speed: 1.13 samples/sec Train-accuracy=0.914238
INFO:root:Epoch[0] Batch [300] Speed: 1.13 samples/sec Train-accuracy=0.912170
INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080
```
## Using the pre-trained model for image segmentation
* similarly, you should firstly download the pre-trained model from [yun.baidu](http://pan.baidu.com/s/1bgz4PC), the symbol and model file is ```FCN8s_VGG16-symbol.json```, ```FCN8s_VGG16-0019.params```
* then put the image in your directory for segmentation, and change the ```img = YOUR_IMAGE_NAME``` in ```image_segmentaion.py```
* lastly, use ```image_segmentaion.py``` to segmentation one image by run in shell ```python image_segmentaion.py```, then you will get the segmentation image like the sample result above.

## Tips
* this is the whole image size training, that is to say, we do not need resize/crop the image to the same size, so the batch_size during training is set to 1.
* the fcn-xs model is baed on vgg16 model, with some crop, deconv, element-sum layer added, so the model is some big, moreover, the example is using whole image size training, if the input image is some large(such as 700*500), then it may very memory consumption, so I suggest you using the GPU with 12G memory.
* if you don't have GPU with 12G memory, maybe you shoud change the ```cut_off_size``` to be a small value when you construct your FileIter, like this:
```python
train_dataiter = FileIter(
root_dir = "./VOC2012",
flist_name = "train.lst",
cut_off_size = 400,
rgb_mean = (123.68, 116.779, 103.939),
)
```
* we are looking forward you to make this example more powerful, thanks.
122 changes: 122 additions & 0 deletions example/fcn-xs/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# pylint: skip-file
""" file iterator for pasval voc 2012"""
import mxnet as mx
import numpy as np
import sys, os
from mxnet.io import DataIter
from PIL import Image

class FileIter(DataIter):
"""FileIter object in fcn-xs example. Taking a file list file to get dataiter.
in this example, we use the whole image training for fcn-xs, that is to say
we do not need resize/crop the image to the same size, so the batch_size is
set to 1 here
Parameters
----------
root_dir : string
the root dir of image/label lie in
flist_name : string
the list file of iamge and label, every line owns the form:
index \t image_data_path \t image_label_path
cut_off_size : int
if the maximal size of one image is larger than cut_off_size, then it will
crop the image with the minimal size of that image
data_name : string
the data name used in symbol data(default data name)
label_name : string
the label name used in symbol softmax_label(default label name)
"""
def __init__(self, root_dir, flist_name,
rgb_mean = (117, 117, 117),
cut_off_size = None,
data_name = "data",
label_name = "softmax_label"):
super(FileIter, self).__init__()
self.root_dir = root_dir
self.flist_name = os.path.join(self.root_dir, flist_name)
self.mean = np.array(rgb_mean) # (R, G, B)
self.cut_off_size = cut_off_size
self.data_name = data_name
self.label_name = label_name

self.num_data = len(open(self.flist_name, 'r').readlines())
self.f = open(self.flist_name, 'r')
self.data, self.label = self._read()
self.cursor = -1

def _read(self):
"""get two list, each list contains two elements: name and nd.array value"""
_, data_img_name, label_img_name = self.f.readline().strip('\n').split("\t")
data = {}
label = {}
data[self.data_name], label[self.label_name] = self._read_img(data_img_name, label_img_name)
return list(data.items()), list(label.items())

def _read_img(self, img_name, label_name):
img = Image.open(os.path.join(self.root_dir, img_name))
label = Image.open(os.path.join(self.root_dir, label_name))
assert img.size == label.size
img = np.array(img, dtype=np.float32) # (h, w, c)
label = np.array(label) # (h, w)
if self.cut_off_size is not None:
max_hw = max(img.shape[0], img.shape[1])
min_hw = min(img.shape[0], img.shape[1])
if min_hw > self.cut_off_size:
rand_start_max = round(np.random.uniform(0, max_hw - self.cut_off_size - 1))
rand_start_min = round(np.random.uniform(0, min_hw - self.cut_off_size - 1))
if img.shape[0] == max_hw :
img = img[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
label = label[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size]
else :
img = img[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
label = label[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size]
elif max_hw > self.cut_off_size:
rand_start = round(np.random.uniform(0, max_hw - min_hw - 1))
if img.shape[0] == max_hw :
img = img[rand_start : rand_start + min_hw, :]
label = label[rand_start : rand_start + min_hw, :]
else :
img = img[:, rand_start : rand_start + min_hw]
label = label[:, rand_start : rand_start + min_hw]
reshaped_mean = self.mean.reshape(1, 1, 3)
img = img - reshaped_mean
img = np.swapaxes(img, 0, 2)
img = np.swapaxes(img, 1, 2) # (c, h, w)
img = np.expand_dims(img, axis=0) # (1, c, h, w)
label = np.array(label) # (h, w)
label = np.expand_dims(label, axis=0) # (1, h, w)
return (img, label)

@property
def provide_data(self):
"""The name and shape of data provided by this iterator"""
return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.data]

@property
def provide_label(self):
"""The name and shape of label provided by this iterator"""
return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label]

def get_batch_size(self):
return 1

def reset(self):
self.cursor = -1
self.f.close()
self.f = open(self.flist_name, 'r')

def iter_next(self):
self.cursor += 1
if(self.cursor < self.num_data-1):
return True
else:
return False

def next(self):
"""return one dict which contains "data" and "label" """
if self.iter_next():
self.data, self.label = self._read()
return {self.data_name : self.data[0][1],
self.label_name : self.label[0][1]}
else:
raise StopIteration
73 changes: 73 additions & 0 deletions example/fcn-xs/fcn_xs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# pylint: skip-file
import sys, os
import argparse
import mxnet as mx
import numpy as np
import logging
import symbol_fcnxs
import init_fcnxs
from data import FileIter
from solver import Solver

logger = logging.getLogger()
logger.setLevel(logging.INFO)
ctx = mx.gpu(0)

def main():
fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536)
fcnxs_model_prefix = "model_pascal/FCN32s_VGG16"
if args.model == "fcn16s":
fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536)
fcnxs_model_prefix = "model_pascal/FCN16s_VGG16"
elif args.model == "fcn8s":
fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536)
fcnxs_model_prefix = "model_pascal/FCN8s_VGG16"
arg_names = fcnxs.list_arguments()
_, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch)
if not args.retrain:
if args.init_type == "vgg16":
fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
elif args.init_type == "fcnxs":
fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs)
train_dataiter = FileIter(
root_dir = "./VOC2012",
flist_name = "train.lst",
# cut_off_size = 400,
rgb_mean = (123.68, 116.779, 103.939),
)
val_dataiter = FileIter(
root_dir = "./VOC2012",
flist_name = "val.lst",
rgb_mean = (123.68, 116.779, 103.939),
)
model = Solver(
ctx = ctx,
symbol = fcnxs,
begin_epoch = 0,
num_epoch = 50,
arg_params = fcnxs_args,
aux_params = fcnxs_auxs,
learning_rate = 1e-10,
momentum = 0.99,
wd = 0.0005)
model.fit(
train_data = train_dataiter,
eval_data = val_dataiter,
batch_end_callback = mx.callback.Speedometer(1, 10),
epoch_end_callback = mx.callback.do_checkpoint(fcnxs_model_prefix))

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Convert vgg16 model to vgg16fc model.')
parser.add_argument('--model', default='fcnxs',
help='The type of fcn-xs model, e.g. fcnxs, fcn16s, fcn8s.')
parser.add_argument('--prefix', default='VGG_FC_ILSVRC_16_layers',
help='The prefix(include path) of vgg16 model with mxnet format.')
parser.add_argument('--epoch', type=int, default=74,
help='The epoch number of vgg16 model.')
parser.add_argument('--init-type', default="vgg16",
help='the init type of fcn-xs model, e.g. vgg16, fcnxs')
parser.add_argument('--retrain', action='store_true', default=False,
help='true means continue training.')
args = parser.parse_args()
logging.info(args)
main()
Loading

0 comments on commit d167481

Please sign in to comment.