From 14f314c0def0f4d12f416da054c417992eb67d73 Mon Sep 17 00:00:00 2001 From: winsty Date: Tue, 29 Dec 2015 19:29:09 +0800 Subject: [PATCH 01/32] add hsl color space aug --- src/io/image_augmenter.h | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index 53d6e6092d29..b1ce8de691c7 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -44,6 +44,12 @@ struct ImageAugmentParam : public dmlc::Parameter { float min_img_size; /*! \brief max image size */ float max_img_size; + /*! \brief max random in H channel */ + int random_h; + /*! \brief max random in S channel */ + int random_s; + /*! \brief max random in L channel */ + int random_l; /*! \brief rotate angle */ int rotate; /*! \brief filled color while padding */ @@ -76,6 +82,12 @@ struct ImageAugmentParam : public dmlc::Parameter { .describe("Augmentation Param: Maxmum image size after resizing."); DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) .describe("Augmentation Param: Minimum image size after resizing."); + DMLC_DECLARE_FIELD(random_h).set_default(0) + .describe("Augmentation Param: Maximum value of H channel in HSV color space."); + DMLC_DECLARE_FIELD(random_s).set_default(0) + .describe("Augmentation Param: Maximum value of S channel in HSV color space."); + DMLC_DECLARE_FIELD(random_l).set_default(0) + .describe("Augmentation Param: Maximum value of V channel in HSV color space."); DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) .describe("Augmentation Param: Rotate angle."); DMLC_DECLARE_FIELD(fill_value).set_default(255) @@ -210,8 +222,32 @@ class ImageAugmenter { cv::Rect roi(x, y, param_.data_shape[2], param_.data_shape[1]); res = res(roi); } + + // color space augmentation + if (param_.random_h != 0 || param_.random_s != 0 || param_.random_l != 0) { + std::uniform_real_distribution rand_uniform(0, 1); + cvtColor(res, res, CV_BGR2HLS); + int h = rand_uniform(*prnd) * param_.random_h * 2 - param_.random_h; + int s = rand_uniform(*prnd) * param_.random_s * 2 - param_.random_s; + int l = rand_uniform(*prnd) * param_.random_l * 2 - param_.random_l; + int temp[3] = {h, l, s}; + int limit[3] = {180, 255, 255}; + for (int i = 0; i < res.rows; ++i) { + for (int j = 0; j < res.cols; ++j) { + for (int k = 0; k < 3; ++k) { + int v = res.at(i, j)[k]; + v += temp[k]; + v = std::max(0, std::min(limit[k], v)); + res.at(i, j)[k] = v; + } + } + } + cvtColor(res, res, CV_HLS2BGR); + } return res; } + + #endif private: From 7054f54eccdea70590981aae1735080ff28d12ee Mon Sep 17 00:00:00 2001 From: winsty Date: Tue, 29 Dec 2015 19:30:53 +0800 Subject: [PATCH 02/32] fix --- src/io/image_augmenter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index b1ce8de691c7..51586887ede2 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -87,7 +87,7 @@ struct ImageAugmentParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(random_s).set_default(0) .describe("Augmentation Param: Maximum value of S channel in HSV color space."); DMLC_DECLARE_FIELD(random_l).set_default(0) - .describe("Augmentation Param: Maximum value of V channel in HSV color space."); + .describe("Augmentation Param: Maximum value of L channel in HSV color space."); DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) .describe("Augmentation Param: Rotate angle."); DMLC_DECLARE_FIELD(fill_value).set_default(255) From 4a9a783110b489bbbdf3acb002ff1bcb2da005f1 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Tue, 29 Dec 2015 22:16:18 -0500 Subject: [PATCH 03/32] let `make clean` do clean_all to avoid frequently asked questions such as #1104 #1011 #958 #920 #721 #719 --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index ea71cd3fff86..63fa524b016f 100644 --- a/Makefile +++ b/Makefile @@ -200,11 +200,11 @@ rpkg: roxygen clean: $(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ - -clean_all: clean cd $(DMLC_CORE); make clean; cd - cd $(PS_PATH); make clean; cd - +clean_all: clean + -include build/*.d -include build/*/*.d ifneq ($(EXTRA_OPERATORS),) From 98cd4359ff2af1490ac7c7d67b3e57ab5a1dd39f Mon Sep 17 00:00:00 2001 From: tornadomeet Date: Sat, 19 Dec 2015 13:57:06 +0800 Subject: [PATCH 04/32] add fcn-xs example for image segmentation refactor the fcn-xs example update run_fcnxs.sh add fcn-xs example for image segmentation --- example/fcn-xs/README.md | 66 ++++++++++ example/fcn-xs/data.py | 115 +++++++++++++++++ example/fcn-xs/fcn_xs.py | 72 +++++++++++ example/fcn-xs/image_segmentaion.py | 60 +++++++++ example/fcn-xs/init_fcnxs.py | 86 +++++++++++++ example/fcn-xs/run_fcnxs.sh | 11 ++ example/fcn-xs/solver.py | 126 ++++++++++++++++++ example/fcn-xs/symbol_fcnxs.py | 192 ++++++++++++++++++++++++++++ python/mxnet/callback.py | 7 +- python/mxnet/lr_scheduler.py | 2 +- python/mxnet/metric.py | 2 +- src/operator/crop-inl.h | 181 ++++++++++++++++++++++++++ src/operator/crop.cc | 29 +++++ src/operator/crop.cu | 18 +++ src/operator/softmax_output-inl.h | 18 ++- 15 files changed, 979 insertions(+), 6 deletions(-) create mode 100644 example/fcn-xs/README.md create mode 100644 example/fcn-xs/data.py create mode 100644 example/fcn-xs/fcn_xs.py create mode 100644 example/fcn-xs/image_segmentaion.py create mode 100644 example/fcn-xs/init_fcnxs.py create mode 100755 example/fcn-xs/run_fcnxs.sh create mode 100644 example/fcn-xs/solver.py create mode 100644 example/fcn-xs/symbol_fcnxs.py create mode 100644 src/operator/crop-inl.h create mode 100644 src/operator/crop.cc create mode 100644 src/operator/crop.cu diff --git a/example/fcn-xs/README.md b/example/fcn-xs/README.md new file mode 100644 index 000000000000..e970eb4a2414 --- /dev/null +++ b/example/fcn-xs/README.md @@ -0,0 +1,66 @@ +FCN-xs EXAMPLES +--------------- +This folder contains the examples of image segmentation in MXNet. + +## Sample results +![fcn-xs pasval_voc result](https://github.com/dmlc/web-data/blob/master/mxnet/image/fcnxs-example-result.jpg) + +we have trained a simple fcn-xs model, the parameter is below: + +| model | lr (fixed) | epoch | +| ---- | ----: | ---------: | +| fcn-32s | 1e-10 | 31 | +| fcn-16s | 1e-12 | 27 | +| fcn-8s | 1e-14 | 19 | + +the training image number is only : 2027, and the Validation image number is: 462 + +## How to train fcn-xs in mxnet +#### step1: download the vgg16fc model and experiment data +* vgg16fc model : you can download the ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1gerce1H). +this is the fully convolution style of the origin +[VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel), and the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt), the vgg16 model has [license](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only. +* experiment data : you can download the ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1jGlOvno), and Extract it. the file/folder will be like: +```JPEGImages folder```, ```SegmentationClass folder```, ```train.lst```, ```val.lst```, ```test.lst``` + +#### step2: train fcn-xs model +* if you want to train the fcn-8s model, it's better for you trained the fcn-32s and fcn-16s model firstly. +when training the fcn-32s model, run in shell ```./run_fcnxs.sh```, the script in it is: +```shell +python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers --epoch=74 --init-type=vgg16 +``` +* in the fcn_xs.py, you may need to change the directory ```root_dir```, ```flist_name```, ``fcnxs_model_prefix``` for your own data. +* when you train fcn-16s or fcn-8s model, you should change the code in ```run_fcnxs.sh``` corresponding, such as when train fcn-16s, comment out the fcn32s script, then it will like this: +```shell + python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 --epoch=31 --init-type=fcnxs +``` +* the output log may like this(when training fcn-8s): +```c++ +INFO:root:Start training with gpu(3) +INFO:root:Epoch[0] Batch [50] Speed: 1.16 samples/sec Train-accuracy=0.894318 +INFO:root:Epoch[0] Batch [100] Speed: 1.11 samples/sec Train-accuracy=0.904681 +INFO:root:Epoch[0] Batch [150] Speed: 1.13 samples/sec Train-accuracy=0.908053 +INFO:root:Epoch[0] Batch [200] Speed: 1.12 samples/sec Train-accuracy=0.912219 +INFO:root:Epoch[0] Batch [250] Speed: 1.13 samples/sec Train-accuracy=0.914238 +INFO:root:Epoch[0] Batch [300] Speed: 1.13 samples/sec Train-accuracy=0.912170 +INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080 +``` + +## Using the pre-trained model for image segmentation +* similarly, you should firstly download the pre-trained model from [yun.baidu](http://pan.baidu.com/s/1gerce1H), the symbol and model file is ```FCN8s_VGG16-symbol.json```, ```FCN8s_VGG16-0019.params``` +* then put the image in your directory for segmentation, and change the ```img = YOUR_IMAGE_NAME``` in ```image_segmentaion.py``` +* lastly, use ```image_segmentaion.py``` to segmentation one image by run in shell ```python image_segmentaion.py```, then you will get the segmentation image like the sample result above. + +## Tips +* this is the whole image size training, that is to say, we do not need resize/crop the image to the same size, so the batch_size during training is set to 1. +* the fcn-xs model is baed on vgg16 model, with some crop, deconv, element-sum layer added, so the model is some big, moreover, the example is using whole image size training, if the input image is some large(such as 700*500), then it may very memory consumption, so I suggest you using the GPU with 12G memory. +* if you don't have GPU with 12G memory, maybe you shoud change the ```cut_off_size``` to be a small value when you construct your FileIter, like this: +```python +train_dataiter = FileIter( + root_dir = "./VOC2012", + flist_name = "train.lst", + cut_off_size = 400, + rgb_mean = (123.68, 116.779, 103.939), + ) +``` +* we are looking forward you to make this example more powerful, thanks. diff --git a/example/fcn-xs/data.py b/example/fcn-xs/data.py new file mode 100644 index 000000000000..a33fd6c7f2e0 --- /dev/null +++ b/example/fcn-xs/data.py @@ -0,0 +1,115 @@ +# pylint: skip-file +""" file iterator for pasval voc 2012""" +import mxnet as mx +import numpy as np +import sys, os +from mxnet.io import DataIter +from skimage import io +from PIL import Image + +class FileIter(DataIter): + """FileIter object in fcn-xs example. Taking a file list file to get dataiter. + in this example, we use the whole image training for fcn-xs, that is to say + we do not need resize/crop the image to the same size, so the batch_size is + set to 1 here + Parameters + ---------- + root_dir : string + the root dir of image/label lie in + flist_name : string + the list file of iamge and label, every line owns the form: + index \t image_data_path \t image_label_path + cut_off_size : int + if the maximal size of one image is larger than cut_off_size, then it will + crop the image with the minimal size of that image + data_name : string + the data name used in symbol data(default data name) + label_name : string + the label name used in symbol softmax_label(default label name) + """ + def __init__(self, root_dir, flist_name, + rgb_mean = (117, 117, 117), + cut_off_size = None, + data_name = "data", + label_name = "softmax_label"): + super(FileIter, self).__init__() + self.root_dir = root_dir + self.flist_name = os.path.join(self.root_dir, flist_name) + self.mean = np.array(rgb_mean) # (R, G, B) + self.cut_off_size = cut_off_size + self.data_name = data_name + self.label_name = label_name + + self.num_data = len(open(self.flist_name, 'r').readlines()) + self.f = open(self.flist_name, 'r') + self.data, self.label = self._read() + self.cursor = -1 + + def _read(self): + """get two list, each list contains two elements: name and nd.array value""" + _, data_img_name, label_img_name = self.f.readline().strip('\n').split("\t") + data = {} + label = {} + data[self.data_name], label[self.label_name] = self._read_img(data_img_name, label_img_name) + return list(data.items()), list(label.items()) + + def _read_img(self, img_name, label_name): + img = Image.open(os.path.join(self.root_dir, img_name)) + label = Image.open(os.path.join(self.root_dir, label_name)) + assert img.size == label.size + img = np.array(img, dtype=np.float32) # (h, w, c) + label = np.array(label) # (h, w) + if self.cut_off_size is not None: + max_hw = max(img.shape[0], img.shape[1]) + min_hw = min(img.shape[0], img.shape[1]) + if max_hw > cut_off_size: + rand_start = round(np.random.uniform(0, max_hw - min_hw - 1)) + if img.shape[0] == max_hw : + img = img[rand_start : rand_start + min_hw, :] + label = label[rand_start : rand_start + min_hw, :] + else : + img = img[:, rand_start : rand_start + min_hw] + label = label[:, rand_start : rand_start + min_hw] + reshaped_mean = self.mean.reshape(1, 1, 3) + img = img - reshaped_mean + img = np.swapaxes(img, 0, 2) + img = np.swapaxes(img, 1, 2) # (c, h, w) + img = np.expand_dims(img, axis=0) # (1, c, h, w) or (1, h, w) + label = np.array(label) # (h, w) + label = np.expand_dims(label, axis=0) # (1, c, h, w) or (1, h, w) + return (img, label) + + @property + def provide_data(self): + """The name and shape of data provided by this iterator""" + return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.data] + + @property + def provide_label(self): + """The name and shape of label provided by this iterator""" + return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label] + + @property + def batch_size(self): + return 1 + + def reset(self): + self.cursor = -1 + self.f.close() + self.f = open(self.flist_name, 'r') + + def iter_next(self): + self.cursor += 1 + if(self.cursor < self.num_data-1): + return True + else: + return False + + def next(self): + """return one dict which contains "data" and "label" """ + if self.iter_next(): + self.data, self.label = self._read() + return {self.data_name : self.data[0][1], + self.label_name : self.label[0][1]} + else: + raise StopIteration diff --git a/example/fcn-xs/fcn_xs.py b/example/fcn-xs/fcn_xs.py new file mode 100644 index 000000000000..5f30ebab10aa --- /dev/null +++ b/example/fcn-xs/fcn_xs.py @@ -0,0 +1,72 @@ +# pylint: skip-file +import sys, os +import argparse +import mxnet as mx +import numpy as np +import logging +import symbol_fcnxs +import init_fcnxs +from data import FileIter +from solver import Solver + +logger = logging.getLogger() +logger.setLevel(logging.INFO) +ctx = mx.gpu(0) + +def main(): + fcnxs = symbol_fcnxs.get_fcn32s_symbol(numclass=21, workspace_default=1536) + fcnxs_model_prefix = "model_pascal/FCN32s_VGG16" + if args.model == "fcn16s": + fcnxs = symbol_fcnxs.get_fcn16s_symbol(numclass=21, workspace_default=1536) + fcnxs_model_prefix = "model_pascal/FCN16s_VGG16" + elif args.model == "fcn8s": + fcnxs = symbol_fcnxs.get_fcn8s_symbol(numclass=21, workspace_default=1536) + fcnxs_model_prefix = "model_pascal/FCN8s_VGG16" + arg_names = fcnxs.list_arguments() + _, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(args.prefix, args.epoch) + if not args.retrain: + if args.init_type == "vgg16": + fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_vgg16(ctx, fcnxs, fcnxs_args, fcnxs_auxs) + elif args.init_type == "fcnxs": + fcnxs_args, fcnxs_auxs = init_fcnxs.init_from_fcnxs(ctx, fcnxs, fcnxs_args, fcnxs_auxs) + train_dataiter = FileIter( + root_dir = "./VOC2012", + flist_name = "train.lst", + rgb_mean = (123.68, 116.779, 103.939), + ) + val_dataiter = FileIter( + root_dir = "./VOC2012", + flist_name = "val.lst", + rgb_mean = (123.68, 116.779, 103.939), + ) + model = Solver( + ctx = ctx, + symbol = fcnxs, + begin_epoch = 0, + num_epoch = 50, + arg_params = fcnxs_args, + aux_params = fcnxs_auxs, + learning_rate = 1e-10, + momentum = 0.99, + wd = 0.0005) + model.fit( + train_data = train_dataiter, + eval_data = val_dataiter, + batch_end_callback = mx.callback.Speedometer(1, 10), + epoch_end_callback = mx.callback.do_checkpoint(fcnxs_model_prefix)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Convert vgg16 model to vgg16fc model.') + parser.add_argument('--model', default='fcnxs', + help='The type of fcn-xs model, e.g. fcnxs, fcn16s, fcn8s.') + parser.add_argument('--prefix', default='VGG_FC_ILSVRC_16_layers', + help='The prefix(include path) of vgg16 model with mxnet format.') + parser.add_argument('--epoch', type=int, default=74, + help='The epoch number of vgg16 model.') + parser.add_argument('--init-type', default="vgg16", + help='the init type of fcn-xs model, e.g. vgg16, fcnxs') + parser.add_argument('--retrain', action='store_true', default=False, + help='true means continue training.') + args = parser.parse_args() + logging.info(args) + main() diff --git a/example/fcn-xs/image_segmentaion.py b/example/fcn-xs/image_segmentaion.py new file mode 100644 index 000000000000..04510a933d1e --- /dev/null +++ b/example/fcn-xs/image_segmentaion.py @@ -0,0 +1,60 @@ +# pylint: skip-file +import numpy as np +import mxnet as mx +from PIL import Image + +pallete = [ 0,0,0, + 128,0,0, + 0,128,0, + 128,128,0, + 0,0,128, + 128,0,128, + 0,128,128, + 128,128,128, + 64,0,0, + 192,0,0, + 64,128,0, + 192,128,0, + 64,0,128, + 192,0,128, + 64,128,128, + 192,128,128, + 0,64,0, + 128,64,0, + 0,192,0, + 128,192,0, + 0,64,128 ] +img = "./person_bicycle.jpg" +seg = img.replace("jpg", "png") +model_previx = "FCN8s_VGG16" +epoch = 19 +ctx = mx.gpu(0) + +def get_data(img_path): + """get the (1, 3, h, w) np.array data for the img_path""" + mean = np.array([123.68, 116.779, 103.939]) # (R,G,B) + img = Image.open(img_path) + img = np.array(img, dtype=np.float32) + reshaped_mean = mean.reshape(1, 1, 3) + img = img - reshaped_mean + img = np.swapaxes(img, 0, 2) + img = np.swapaxes(img, 1, 2) + img = np.expand_dims(img, axis=0) + return img + +def main(): + fcn32s, fcn32s_arg_params, fcn32s_aux_params = mx.model.load_checkpoint(model_previx, epoch) + fcn32s_arg_params["data"] = mx.nd.array(get_data(img), ctx) + data_shape = fcn32s_arg_params["data"].shape + label_shape = (1, data_shape[2]*data_shape[3]) + fcn32s_arg_params["softmax_label"] = mx.nd.empty(label_shape, ctx) + exector = fcn32s.bind(ctx, fcn32s_arg_params ,args_grad=None, grad_req="null", aux_states=fcn32s_arg_params) + exector.forward(is_train=False) + output = exector.outputs[0] + out_img = np.uint8(np.squeeze(output.asnumpy().argmax(axis=1))) + out_img = Image.fromarray(out_img) + out_img.putpalette(pallete) + out_img.save(seg) + +if __name__ == "__main__": + main() diff --git a/example/fcn-xs/init_fcnxs.py b/example/fcn-xs/init_fcnxs.py new file mode 100644 index 000000000000..126805e3c137 --- /dev/null +++ b/example/fcn-xs/init_fcnxs.py @@ -0,0 +1,86 @@ +# pylint: skip-file +import mxnet as mx +import numpy as np +import sys +import logging + +logger = logging.getLogger() +logger.setLevel(logging.INFO) + +# make a bilinear interpolation kernel, return a numpy.ndarray +def upsample_filt(size): + factor = (size + 1) // 2 + if size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + og = np.ogrid[:size, :size] + return (1 - abs(og[0] - center) / factor) * \ + (1 - abs(og[1] - center) / factor) + +def init_from_vgg16(ctx, fcnxs_symbol, vgg16fc_args, vgg16fc_auxs): + fcnxs_args = vgg16fc_args.copy() + fcnxs_auxs = vgg16fc_auxs.copy() + for k,v in fcnxs_args.items(): + if(v.context != ctx): + fcnxs_args[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcnxs_args[k]) + for k,v in fcnxs_auxs.items(): + if(v.context != ctx): + fcnxs_auxs[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcnxs_auxs[k]) + data_shape=(1,3,500,500) + arg_names = fcnxs_symbol.list_arguments() + arg_shapes, _, _ = fcnxs_symbol.infer_shape(data=data_shape) + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_weight', 'score_bias', 'score_pool4_weight', 'score_pool4_bias', \ + 'score_pool3_weight', 'score_pool3_bias']]) + fcnxs_args.update(rest_params) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) + if x[0] in ["bigscore_weight", 'score2_weight', 'score4_weight']]) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcnxs_args[k] = mx.nd.array(initw, ctx) + return fcnxs_args, fcnxs_auxs + +def init_from_fcnxs(ctx, fcnxs_symbol, fcnxs_args_from, fcnxs_auxs_from): + fcnxs_args = fcnxs_args_from.copy() + fcnxs_auxs = fcnxs_auxs_from.copy() + for k,v in fcnxs_args.items(): + if(v.context != ctx): + fcnxs_args[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcnxs_args[k]) + for k,v in fcnxs_auxs.items(): + if(v.context != ctx): + fcnxs_auxs[k] = mx.nd.zeros(v.shape, ctx) + v.copyto(fcnxs_auxs[k]) + data_shape=(1,3,500,500) + arg_names = fcnxs_symbol.list_arguments() + arg_shapes, _, _ = fcnxs_symbol.infer_shape(data=data_shape) + rest_params = {} + deconv_params = {} + # this is fcn8s init from fcn16s + if 'score_pool3_weight' in arg_names: + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_pool3_bias', 'score_pool3_weight']]) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ["bigscore_weight", 'score4_weight']]) + # this is fcn16s init from fcn32s + elif 'score_pool4_weight' in arg_names: + rest_params = dict([(x[0], mx.nd.zeros(x[1], ctx)) for x in zip(arg_names, arg_shapes) + if x[0] in ['score_pool4_weight', 'score_pool4_bias']]) + deconv_params = dict([(x[0], x[1]) for x in zip(arg_names, arg_shapes) if x[0] \ + in ["bigscore_weight", 'score2_weight']]) + # this is fcn32s init + else: + logging.error("you are init the fcn32s model, so you should use init_from_vgg16()") + sys.exit() + fcnxs_args.update(rest_params) + for k, v in deconv_params.items(): + filt = upsample_filt(v[3]) + initw = np.zeros(v) + initw[range(v[0]), range(v[1]), :, :] = filt # becareful here is the slice assing + fcnxs_args[k] = mx.nd.array(initw, ctx) + return fcnxs_args, fcnxs_auxs diff --git a/example/fcn-xs/run_fcnxs.sh b/example/fcn-xs/run_fcnxs.sh new file mode 100755 index 000000000000..8dca9b8231c5 --- /dev/null +++ b/example/fcn-xs/run_fcnxs.sh @@ -0,0 +1,11 @@ +# train fcn-32s model +python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers \ + --epoch=74 --init-type=vgg16 + +# # train fcn-16s model +# python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 \ +# --epoch=31 --init-type=fcnxs + +# # train fcn-8s model +# python -u fcn_xs.py --model=fcn8s --prefix=FCN16s_VGG16 \ +# --epoch=27 --init-type=fcnxs diff --git a/example/fcn-xs/solver.py b/example/fcn-xs/solver.py new file mode 100644 index 000000000000..edd871be1736 --- /dev/null +++ b/example/fcn-xs/solver.py @@ -0,0 +1,126 @@ +# pylint: skip-file +import numpy as np +import mxnet as mx +import time +import logging +from collections import namedtuple +from mxnet import optimizer as opt +from mxnet.optimizer import get_updater +from mxnet import metric + +# Parameter to pass to batch_end_callback +BatchEndParam = namedtuple('BatchEndParams', ['epoch', 'nbatch', 'eval_metric']) +class Solver(object): + def __init__(self, symbol, ctx=None, + begin_epoch=0, num_epoch=None, + arg_params=None, aux_params=None, + optimizer='sgd', **kwargs): + self.symbol = symbol + if ctx is None: + ctx = mx.cpu() + self.ctx = ctx + self.begin_epoch = begin_epoch + self.num_epoch = num_epoch + self.arg_params = arg_params + self.aux_params = aux_params + self.optimizer = optimizer + self.kwargs = kwargs.copy() + + def fit(self, train_data, eval_data=None, + eval_metric='acc', + grad_req='write', + epoch_end_callback=None, + batch_end_callback=None, + kvstore='local', + logger=None): + if logger is None: + logger = logging + logging.info('Start training with %s', str(self.ctx)) + arg_shapes, out_shapes, aux_shapes = self.symbol.infer_shape(data=train_data.provide_data[0][1]) + arg_names = self.symbol.list_arguments() + if grad_req != 'null': + self.grad_params = {} + for name, shape in zip(arg_names, arg_shapes): + if not (name.endswith('data') or name.endswith('label')): + self.grad_params[name] = mx.nd.zeros(shape, self.ctx) + else: + self.grad_params = None + aux_names = self.symbol.list_auxiliary_states() + self.aux_params = {k : nd.zeros(s) for k, s in zip(aux_names, aux_shapes)} + data_name = train_data.data_name + label_name = train_data.label_name + input_names = [data_name, label_name] + self.optimizer = opt.create(self.optimizer, rescale_grad=(1.0/train_data.batch_size), **(self.kwargs)) + self.updater = get_updater(self.optimizer) + eval_metric = metric.create(eval_metric) + # begin training + for epoch in range(self.begin_epoch, self.num_epoch): + nbatch = 0 + train_data.reset() + eval_metric.reset() + for data in train_data: + nbatch += 1 + label_shape = data[label_name].shape + self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) + self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], \ + label_shape[1]*label_shape[2]), self.ctx) + output_names = self.symbol.list_outputs() + self.exector = self.symbol.bind(self.ctx, self.arg_params, + args_grad=self.grad_params, + grad_req=grad_req, + aux_states=self.aux_params) + assert len(self.symbol.list_arguments()) == len(self.exector.grad_arrays) + update_dict = {name: nd for name, nd in zip(self.symbol.list_arguments(), \ + self.exector.grad_arrays) if nd} + output_dict = {} + output_buff = {} + for key, arr in zip(self.symbol.list_outputs(), self.exector.outputs): + output_dict[key] = arr + output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu()) + self.exector.forward(is_train=True) + for key in output_dict: + output_dict[key].copyto(output_buff[key]) + self.exector.backward() + for key, arr in update_dict.items(): + if key != "bigscore_weight": + self.updater(key, arr, self.arg_params[key]) + pred_shape = self.exector.outputs[0].shape + label = mx.nd.array(data[label_name].reshape(label_shape[0], label_shape[1]*label_shape[2])) + pred = mx.nd.array(output_buff["softmax_output"].asnumpy().reshape(pred_shape[0], \ + pred_shape[1], pred_shape[2]*pred_shape[3])) + eval_metric.update([label], [pred]) + self.exector.outputs[0].wait_to_read() + batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch, eval_metric=eval_metric) + batch_end_callback(batch_end_params) + if epoch_end_callback != None: + epoch_end_callback(epoch, self.symbol, self.arg_params, self.aux_params) + name, value = eval_metric.get() + logger.info(" --->Epoch[%d] Train-%s=%f", epoch, name, value) + # evaluation + if eval_data: + logger.info(" in eval process...") + nbatch = 0 + eval_data.reset() + eval_metric.reset() + for data in eval_data: + nbatch += 1 + label_shape = data[label_name].shape + self.arg_params[data_name] = mx.nd.array(data[data_name], self.ctx) + self.arg_params[label_name] = mx.nd.array(data[label_name].reshape(label_shape[0], \ + label_shape[1]*label_shape[2]), self.ctx) + exector = self.symbol.bind(self.ctx, self.arg_params, + args_grad=self.grad_params, + grad_req=grad_req, + aux_states=self.aux_params) + cpu_output_array = mx.nd.zeros(exector.outputs[0].shape) + exector.forward(is_train=False) + exector.outputs[0].copyto(cpu_output_array) + pred_shape = cpu_output_array.shape + label = mx.nd.array(data[label_name].reshape(label_shape[0], \ + label_shape[1]*label_shape[2])) + pred = mx.nd.array(cpu_output_array.asnumpy().reshape(pred_shape[0], \ + pred_shape[1], pred_shape[2]*pred_shape[3])) + eval_metric.update([label], [pred]) + exector.outputs[0].wait_to_read() + name, value = eval_metric.get() + logger.info('batch[%d] Validation-%s=%f', nbatch, name, value) diff --git a/example/fcn-xs/symbol_fcnxs.py b/example/fcn-xs/symbol_fcnxs.py new file mode 100644 index 000000000000..2003d48081f8 --- /dev/null +++ b/example/fcn-xs/symbol_fcnxs.py @@ -0,0 +1,192 @@ +# pylint: skip-file +import mxnet as mx + +def filter_map(kernel=1, stride=1, pad=0): + # why not return (stride, (kernel-stride)/2-pad)?? + return (stride, (kernel-1)/2-pad) + +def compose_fp(fp_first, fp_second): + return (fp_first[0]*fp_second[0], fp_first[0]*fp_second[1]+fp_first[1]) + +def compose_fp_list(fp_list): + fp_out = (1.0, 0.0) + for fp in fp_list: + fp_out = compose_fp(fp_out, fp) + return fp_out + +def inv_fp(fp_in): + return (1.0/fp_in[0], -1.0*fp_in[1]/fp_in[0]) + +def offset(): + conv1_1_fp = filter_map(kernel=3, pad=100) + conv1_2_fp = conv2_1_fp = conv2_2_fp = conv3_1_fp = conv3_2_fp = conv3_3_fp \ + = conv4_1_fp = conv4_2_fp = conv4_3_fp = conv5_1_fp = conv5_2_fp \ + = conv5_3_fp = filter_map(kernel=3, pad=1) + pool1_fp = pool2_fp = pool3_fp = pool4_fp = pool5_fp = filter_map(kernel=2, stride=2) + fc6_fp = filter_map(kernel=7) + fc7_fp = score_fp = score_pool4_fp = score_pool3_fp = filter_map() + # for fcn-32s + fcn32s_upscore_fp = inv_fp(filter_map(kernel=64, stride=32)) + fcn32s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp, + pool2_fp, conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp, + conv4_1_fp, conv4_2_fp, conv4_3_fp, pool4_fp, conv5_1_fp, + conv5_2_fp, conv5_3_fp, pool5_fp, fc6_fp, fc7_fp, score_fp, + fcn32s_upscore_fp] + crop = {} + crop["fcn32s_upscore"] = (-int(round(compose_fp_list(fcn32s_upscore_list)[1])), + -int(round(compose_fp_list(fcn32s_upscore_list)[1]))) + # for fcn-16s + score2_fp = inv_fp(filter_map(kernel=4, stride=2)) + fcn16s_upscore_fp = inv_fp(filter_map(kernel=32, stride=16)) + score_pool4c_fp_list = [inv_fp(score2_fp), inv_fp(score_fp), inv_fp(fc7_fp), inv_fp(fc6_fp), + inv_fp(pool5_fp), inv_fp(conv5_3_fp), inv_fp(conv5_2_fp), + inv_fp(conv5_1_fp), score_pool4_fp] + crop["score_pool4c"] = (-int(round(compose_fp_list(score_pool4c_fp_list)[1])), + -int(round(compose_fp_list(score_pool4c_fp_list)[1]))) + fcn16s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp, + pool2_fp, conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp, + conv4_1_fp, conv4_2_fp, conv4_3_fp, pool4_fp, score_pool4_fp, + inv_fp((1, -crop["score_pool4c"][0])), fcn16s_upscore_fp] + crop["fcn16s_upscore"] = (-int(round(compose_fp_list(fcn16s_upscore_list)[1])), + -int(round(compose_fp_list(fcn16s_upscore_list)[1]))) + # for fcn-8s + score4_fp = inv_fp(filter_map(kernel=4, stride=2)) + fcn8s_upscore_fp = inv_fp(filter_map(kernel=16, stride=8)) + score_pool3c_fp_list = [inv_fp(score4_fp), (1, -crop["score_pool4c"][0]), inv_fp(score_pool4_fp), + inv_fp(pool4_fp), inv_fp(conv4_3_fp), inv_fp(conv4_2_fp), + inv_fp(conv4_1_fp), score_pool3_fp, score_pool3_fp] + crop["score_pool3c"] = (-int(round(compose_fp_list(score_pool3c_fp_list)[1])), + -int(round(compose_fp_list(score_pool3c_fp_list)[1]))) + fcn8s_upscore_list = [conv1_1_fp, conv1_2_fp, pool1_fp, conv2_1_fp, conv2_2_fp, pool2_fp, + conv3_1_fp, conv3_2_fp, conv3_3_fp, pool3_fp, score_pool3_fp, + inv_fp((1, -crop["score_pool3c"][0])), fcn8s_upscore_fp] + crop["fcn8s_upscore"] = (-int(round(compose_fp_list(fcn8s_upscore_list)[1])), + -int(round(compose_fp_list(fcn8s_upscore_list)[1]))) + return crop + +def vgg16_pool3(input, workspace_default=1024): + # group 1 + conv1_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(100, 100), num_filter=64, + workspace=workspace_default, name="conv1_1") + relu1_1 = mx.symbol.Activation(data=conv1_1, act_type="relu", name="relu1_1") + conv1_2 = mx.symbol.Convolution(data=relu1_1, kernel=(3, 3), pad=(1, 1), num_filter=64, + workspace=workspace_default, name="conv1_2") + relu1_2 = mx.symbol.Activation(data=conv1_2, act_type="relu", name="relu1_2") + pool1 = mx.symbol.Pooling(data=relu1_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool1") + # group 2 + conv2_1 = mx.symbol.Convolution(data=pool1, kernel=(3, 3), pad=(1, 1), num_filter=128, + workspace=workspace_default, name="conv2_1") + relu2_1 = mx.symbol.Activation(data=conv2_1, act_type="relu", name="relu2_1") + conv2_2 = mx.symbol.Convolution(data=relu2_1, kernel=(3, 3), pad=(1, 1), num_filter=128, + workspace=workspace_default, name="conv2_2") + relu2_2 = mx.symbol.Activation(data=conv2_2, act_type="relu", name="relu2_2") + pool2 = mx.symbol.Pooling(data=relu2_2, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool2") + # group 3 + conv3_1 = mx.symbol.Convolution(data=pool2, kernel=(3, 3), pad=(1, 1), num_filter=256, + workspace=workspace_default, name="conv3_1") + relu3_1 = mx.symbol.Activation(data=conv3_1, act_type="relu", name="relu3_1") + conv3_2 = mx.symbol.Convolution(data=relu3_1, kernel=(3, 3), pad=(1, 1), num_filter=256, + workspace=workspace_default, name="conv3_2") + relu3_2 = mx.symbol.Activation(data=conv3_2, act_type="relu", name="relu3_2") + conv3_3 = mx.symbol.Convolution(data=relu3_2, kernel=(3, 3), pad=(1, 1), num_filter=256, + workspace=workspace_default, name="conv3_3") + relu3_3 = mx.symbol.Activation(data=conv3_3, act_type="relu", name="relu3_3") + pool3 = mx.symbol.Pooling(data=relu3_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool3") + return pool3 + +def vgg16_pool4(input, workspace_default=1024): + # group 4 + conv4_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv4_1") + relu4_1 = mx.symbol.Activation(data=conv4_1, act_type="relu", name="relu4_1") + conv4_2 = mx.symbol.Convolution(data=relu4_1, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv4_2") + relu4_2 = mx.symbol.Activation(data=conv4_2, act_type="relu", name="relu4_2") + conv4_3 = mx.symbol.Convolution(data=relu4_2, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv4_3") + relu4_3 = mx.symbol.Activation(data=conv4_3, act_type="relu", name="relu4_3") + pool4 = mx.symbol.Pooling(data=relu4_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool4") + return pool4 + +def vgg16_score(input, numclass, workspace_default=1024): + # group 5 + conv5_1 = mx.symbol.Convolution(data=input, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv5_1") + relu5_1 = mx.symbol.Activation(data=conv5_1, act_type="relu", name="relu5_1") + conv5_2 = mx.symbol.Convolution(data=relu5_1, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv5_2") + relu5_2 = mx.symbol.Activation(data=conv5_2, act_type="relu", name="conv1_2") + conv5_3 = mx.symbol.Convolution(data=relu5_2, kernel=(3, 3), pad=(1, 1), num_filter=512, + workspace=workspace_default, name="conv5_3") + relu5_3 = mx.symbol.Activation(data=conv5_3, act_type="relu", name="relu5_3") + pool5 = mx.symbol.Pooling(data=relu5_3, pool_type="max", kernel=(2, 2), stride=(2,2), name="pool5") + # group 6 + fc6 = mx.symbol.Convolution(data=pool5, kernel=(7, 7), num_filter=4096, + workspace=workspace_default, name="fc6") + relu6 = mx.symbol.Activation(data=fc6, act_type="relu", name="relu6") + drop6 = mx.symbol.Dropout(data=relu6, p=0.5, name="drop6") + # group 7 + fc7 = mx.symbol.Convolution(data=drop6, kernel=(1, 1), num_filter=4096, + workspace=workspace_default, name="fc7") + relu7 = mx.symbol.Activation(data=fc7, act_type="relu", name="relu7") + drop7 = mx.symbol.Dropout(data=relu7, p=0.5, name="drop7") + # group 8 + score = mx.symbol.Convolution(data=drop7, kernel=(1, 1), num_filter=numclass, + workspace=workspace_default, name="score") + return score + +def fcnxs_score(input, crop, offset, kernel=(64,64), stride=(32,32), numclass=21, workspace_default=1024): + # score out + bigscore = mx.symbol.Deconvolution(data=input, kernel=kernel, stride=stride, num_filter=numclass, + workspace=workspace_default, name="bigscore") + upscore = mx.symbol.Crop(data=bigscore, crop_like=crop, offset=offset, name="upscore") + softmax = mx.symbol.SoftmaxOutput(data=upscore, multi_output=True, use_ignore=True, ignore_label=255, name="softmax") + return softmax + +def get_fcn32s_symbol(numclass=21, workspace_default=1024): + data = mx.symbol.Variable(name="data") + pool3 = vgg16_pool3(data, workspace_default) + pool4 = vgg16_pool4(pool3, workspace_default) + score = vgg16_score(pool4, numclass, workspace_default) + softmax = fcnxs_score(score, data, offset()["fcn32s_upscore"], (64,64), (32,32), numclass, workspace_default) + return softmax + +def get_fcn16s_symbol(numclass=21, workspace_default=1024): + data = mx.symbol.Variable(name="data") + pool3 = vgg16_pool3(data, workspace_default) + pool4 = vgg16_pool4(pool3, workspace_default) + score = vgg16_score(pool4, numclass, workspace_default) + # score 2X + score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2), num_filter=numclass, + workspace=workspace_default, name="score2") # 2X + score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass, + workspace=workspace_default, name="score_pool4") + score_pool4c = mx.symbol.Crop(data=score_pool4, crop_like=score2, + offset=offset()["score_pool4c"], name="score_pool4c") + score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c], name='score_fused') + softmax = fcnxs_score(score_fused, data, offset()["fcn16s_upscore"], (32, 32), (16, 16), numclass, workspace_default) + return softmax + +def get_fcn8s_symbol(numclass=21, workspace_default=1024): + data = mx.symbol.Variable(name="data") + pool3 = vgg16_pool3(data, workspace_default) + pool4 = vgg16_pool4(pool3, workspace_default) + score = vgg16_score(pool4, numclass, workspace_default) + # score 2X + score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2),num_filter=21, + workspace=workspace_default, name="score2") # 2X + score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=21, + workspace=workspace_default, name="score_pool4") + score_pool4c = mx.symbol.Crop(data=score_pool4, crop_like=score2, + offset=offset()["score_pool4c"], name="score_pool4c") + score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c], name='score_fused') + # score 4X + score4 = mx.symbol.Deconvolution(data=score_fused, kernel=(4, 4), stride=(2, 2),num_filter=21, + workspace=workspace_default, name="score4") # 4X + score_pool3 = mx.symbol.Convolution(data=pool3, kernel=(1, 1), num_filter=21, + workspace=workspace_default, name="score_pool3") + score_pool3c = mx.symbol.Crop(data=score_pool3, crop_like=score4, + offset=offset()["score_pool3c"], name="score_pool3c") + score_final = mx.symbol.ElementWiseSum(*[score4, score_pool3c], name='score_final') + softmax = fcnxs_score(score_final, data, offset()["fcn8s_upscore"], (16, 16), (8, 8), numclass, workspace_default) + return softmax diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index 8d08e40ba7d3..913772a91e54 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -76,7 +76,12 @@ def __call__(self, param): if self.init: if count % self.frequent == 0: speed = self.frequent * self.batch_size / (time.time() - self.tic) - logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", + if param.eval_metric is not None: + name, value = param.eval_metric.get() + logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f", + param.epoch, count, speed, name, value) + else: + logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", param.epoch, count, speed) self.tic = time.time() else: diff --git a/python/mxnet/lr_scheduler.py b/python/mxnet/lr_scheduler.py index c008f058ab2a..e40e146a0af8 100644 --- a/python/mxnet/lr_scheduler.py +++ b/python/mxnet/lr_scheduler.py @@ -71,6 +71,6 @@ def __call__(self, num_update): if num_update > self.count + self.step: self.count += self.step self.base_lr *= self.factor - logging.info("Update[%d]: Change learning rate to %.5f", + logging.info("Update[%d]: Change learning rate to %0.5e", num_update, self.base_lr) return self.base_lr diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 8e3efe511c0c..4cb807e7232c 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -56,7 +56,7 @@ def update(self, labels, preds): if label.shape[0] < pred_label.shape[0]: raise Exception("Predict label is more than data label? ") self.sum_metric += numpy.sum(pred_label == label[:pred_label.shape[0]]) - num_inst = pred_label.shape[0] + num_inst = pred_label.size self.num_inst += num_inst class MAE(EvalMetric): diff --git a/src/operator/crop-inl.h b/src/operator/crop-inl.h new file mode 100644 index 000000000000..750ab833da3c --- /dev/null +++ b/src/operator/crop-inl.h @@ -0,0 +1,181 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file crop-inl.h + * \brief + * \author Wei Wu +*/ +#ifndef MXNET_OPERATOR_CROP_INL_H_ +#define MXNET_OPERATOR_CROP_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +namespace crop_enum { +enum CropOpInputs {kData, kCropLike}; +enum CropOpOutputs {kOut}; +} // namespace crop_enum + +struct CropParam : public dmlc::Parameter { + TShape offset; + bool center_crop; + DMLC_DECLARE_PARAMETER(CropParam) { + int shape[] = {0, 0}; + DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2)) + .describe("corp offset coordinate: (y, x)"); + DMLC_DECLARE_FIELD(center_crop).set_default(false) + .describe("If set to true, then it will use be the center_crop," + "or it will crop using the shape of crop_like"); + } +}; // struct CropParam + +template +class CropOp : public Operator { + public: + explicit CropOp(CropParam param) { + this->param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(static_cast(in_data.size()), 2); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req[crop_enum::kOut], kWriteTo); + Stream *s = ctx.get_stream(); + Tensor data = in_data[crop_enum::kData].get(s); + Tensor out = out_data[crop_enum::kOut].get(s); + offset_hw_ = InferCropOfferset(data.shape_, out.shape_); + out = crop(data, Shape2(out.size(2), out.size(3)), offset_hw_[0], offset_hw_[1]); + } + + // because the crop_like input is only used with it's shape, so we should be + // careful setting its backwrd grad value to zeros, so that it will not hurt + // the connection of crop_like. + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_grad.size(), 2) << in_grad.size(); + CHECK_EQ(out_grad.size(), 1) << out_grad.size(); + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[crop_enum::kOut].get(s); + Tensor gdata = in_grad[crop_enum::kData].get(s); + Tensor gcrop_like = in_grad[crop_enum::kCropLike].get(s); + gcrop_like = (real_t)0.0f; + offset_hw_ = InferCropOfferset(gdata.shape_, grad.shape_); + gdata = (real_t)0.0f; + slice<3>(slice<2>(gdata, offset_hw_[0], offset_hw_[0]+grad.size(2)), + offset_hw_[1], offset_hw_[1]+grad.size(3)) = grad; + } + + private: + CropParam param_; + std::vector offset_hw_; + std::vector InferCropOfferset(const mshadow::Shape<4> &data_shape, + const mshadow::Shape<4> &out_shape) { + std::vector offset_hw; + CHECK_GE(data_shape[2], out_shape[2]) << + "data_shape'height should be larger than that of out_shape"; + CHECK_GE(data_shape[3], out_shape[3]) << + "data_shape'weight should be larger than that of out_shape"; + if (param_.center_crop) { + offset_hw.push_back(static_cast((data_shape[2]-out_shape[2])/2)); + offset_hw.push_back(static_cast((data_shape[3]-out_shape[3])/2)); + } else { + CHECK_GE(static_cast(param_.offset[0]), 0) << + "offset[0] should be larger than 0"; + CHECK_LE(static_cast(param_.offset[0]), data_shape[2]-out_shape[2]) << + "offset[0] should be less than the residual space of height"; + CHECK_GE(static_cast(param_.offset[1]), 0) << + "offset[1] should be larger than 0"; + CHECK_LE(static_cast(param_.offset[1]), data_shape[3]-out_shape[3]) << + "offset[1] should be less than the residual space of width"; + offset_hw.push_back(static_cast(param_.offset[0])); + offset_hw.push_back(static_cast(param_.offset[1])); + } + return offset_hw; + } +}; // class CropOp + +template +Operator *CreateOp(CropParam param); + +#if DMLC_USE_CXX11 +class CropProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + std::vector ListArguments() const override { + return {"data", "crop_like"}; + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 2) << "Input:[data, crop_like]"; + TShape data_shape = in_shape->at(crop_enum::kData); + if (data_shape.ndim() == 0) return false; + CHECK_EQ(data_shape.ndim(), 4) << \ + "Input data should be 4D in batch-num_filter-y-x"; + TShape crop_shape = in_shape->at(crop_enum::kCropLike); + if (crop_shape.ndim() == 0) return false; + CHECK_EQ(crop_shape.ndim(), 4) << \ + "Input crop_like should be 4D in batch-num_filter/batch-num_channel-y-x"; + out_shape->clear(); + data_shape[2] = crop_shape[2]; + data_shape[3] = crop_shape[3]; + out_shape->push_back(data_shape); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new CropProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "Crop"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return out_grad; + } + + Operator* CreateOperator(Context ctx) const override; + + private: + CropParam param_; +}; // class CropProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CROP_INL_H_ diff --git a/src/operator/crop.cc b/src/operator/crop.cc new file mode 100644 index 000000000000..5a3315c24d63 --- /dev/null +++ b/src/operator/crop.cc @@ -0,0 +1,29 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concat.cc + * \brief + * \author Wei Wu +*/ + +#include "./crop-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(CropParam param) { + return new CropOp(param); +} + +Operator* CropProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(CropParam); + +MXNET_REGISTER_OP_PROPERTY(Crop, CropProp) +.add_argument("data", "Symbol", "Input data to the CropOp.") +.add_argument("crop_like", "Symbol", "crop_like data to the CropOp.") +.add_arguments(CropParam::__FIELDS__()) +.describe("Crop the 2th and 3th dim of input data, with the corresponding size of crop_like."); +} // namespace op +} // namespace mxnet diff --git a/src/operator/crop.cu b/src/operator/crop.cu new file mode 100644 index 000000000000..64f8cb219f30 --- /dev/null +++ b/src/operator/crop.cu @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concat.cu + * \brief + * \author Wei Wu +*/ + +#include "./crop-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(CropParam param) { + return new CropOp(param); +} + +} // namespace op +} // namespace mxnet diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 60877a6b0c3c..fb026df72e55 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -2,7 +2,7 @@ * Copyright (c) 2015 by Contributors * \file softmax_output-inl.h * \brief - * \author Junyuan Xie + * \author Bing Xu */ #ifndef MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ #define MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ @@ -27,14 +27,22 @@ enum SoftmaxOutputOpOutputs {kOut}; struct SoftmaxOutputParam : public dmlc::Parameter { float grad_scale; + float ignore_label; bool multi_output; + bool use_ignore; DMLC_DECLARE_PARAMETER(SoftmaxOutputParam) { DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f) .describe("Scale the gradient by a float factor"); + DMLC_DECLARE_FIELD(ignore_label).set_default(-1.0f) + .describe("the ignore_label will not work in backward, and this only" + "be used when multi_output=true"); DMLC_DECLARE_FIELD(multi_output).set_default(false) .describe("If set to true, for a (n,k,x_1,..,x_n) dimensional" "input tensor, softmax will generate n*x_1*...*x_n output, each" "has k classes"); + DMLC_DECLARE_FIELD(use_ignore).set_default(false) + .describe("If set to true, the ignore_label value will not contributor" + "to the backward gradient"); }; }; @@ -88,8 +96,12 @@ class SoftmaxOutputOp : public Operator { Tensor label = in_data[softmaxout_enum::kLabel].FlatTo2D(s); Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s); Tensor grad = in_grad[softmaxout_enum::kData].get_with_shape(s3, s); - SoftmaxGrad(grad, out, label); - grad *= param_.grad_scale/s3[2]; + if (param_.use_ignore) { + SoftmaxGrad(grad, out, label, static_cast(param_.ignore_label)); + } else { + SoftmaxGrad(grad, out, label); + } + grad *= param_.grad_scale; } else { Tensor label = in_data[softmaxout_enum::kLabel].get(s); Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); From cc0bc333b8dcad49fd6aafa51eee2ccc1d6c3038 Mon Sep 17 00:00:00 2001 From: tornadomeet Date: Mon, 21 Dec 2015 10:58:27 +0800 Subject: [PATCH 05/32] fix attribute set problem --- example/fcn-xs/data.py | 3 +-- example/fcn-xs/solver.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/example/fcn-xs/data.py b/example/fcn-xs/data.py index a33fd6c7f2e0..96eb44a44192 100644 --- a/example/fcn-xs/data.py +++ b/example/fcn-xs/data.py @@ -89,8 +89,7 @@ def provide_label(self): """The name and shape of label provided by this iterator""" return [(k, tuple([1] + list(v.shape[1:]))) for k, v in self.label] - @property - def batch_size(self): + def get_batch_size(self): return 1 def reset(self): diff --git a/example/fcn-xs/solver.py b/example/fcn-xs/solver.py index edd871be1736..953e0a986fd2 100644 --- a/example/fcn-xs/solver.py +++ b/example/fcn-xs/solver.py @@ -50,7 +50,7 @@ def fit(self, train_data, eval_data=None, data_name = train_data.data_name label_name = train_data.label_name input_names = [data_name, label_name] - self.optimizer = opt.create(self.optimizer, rescale_grad=(1.0/train_data.batch_size), **(self.kwargs)) + self.optimizer = opt.create(self.optimizer, rescale_grad=(1.0/train_data.get_batch_size()), **(self.kwargs)) self.updater = get_updater(self.optimizer) eval_metric = metric.create(eval_metric) # begin training From f890b79aa80827b92198f89d659ffcfd966f3dca Mon Sep 17 00:00:00 2001 From: tornadomeet Date: Mon, 21 Dec 2015 14:04:36 +0800 Subject: [PATCH 06/32] fix num_filter --- example/fcn-xs/symbol_fcnxs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/fcn-xs/symbol_fcnxs.py b/example/fcn-xs/symbol_fcnxs.py index 2003d48081f8..0ef6bc363c8b 100644 --- a/example/fcn-xs/symbol_fcnxs.py +++ b/example/fcn-xs/symbol_fcnxs.py @@ -173,17 +173,17 @@ def get_fcn8s_symbol(numclass=21, workspace_default=1024): pool4 = vgg16_pool4(pool3, workspace_default) score = vgg16_score(pool4, numclass, workspace_default) # score 2X - score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2),num_filter=21, + score2 = mx.symbol.Deconvolution(data=score, kernel=(4, 4), stride=(2, 2),num_filter=numclass, workspace=workspace_default, name="score2") # 2X - score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=21, + score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass, workspace=workspace_default, name="score_pool4") score_pool4c = mx.symbol.Crop(data=score_pool4, crop_like=score2, offset=offset()["score_pool4c"], name="score_pool4c") score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c], name='score_fused') # score 4X - score4 = mx.symbol.Deconvolution(data=score_fused, kernel=(4, 4), stride=(2, 2),num_filter=21, + score4 = mx.symbol.Deconvolution(data=score_fused, kernel=(4, 4), stride=(2, 2),num_filter=numclass, workspace=workspace_default, name="score4") # 4X - score_pool3 = mx.symbol.Convolution(data=pool3, kernel=(1, 1), num_filter=21, + score_pool3 = mx.symbol.Convolution(data=pool3, kernel=(1, 1), num_filter=numclass, workspace=workspace_default, name="score_pool3") score_pool3c = mx.symbol.Crop(data=score_pool3, crop_like=score4, offset=offset()["score_pool3c"], name="score_pool3c") From 3a66df223277a13fc88dbc3553cf87244403903a Mon Sep 17 00:00:00 2001 From: tornadomeet Date: Mon, 21 Dec 2015 15:28:18 +0800 Subject: [PATCH 07/32] fix lint --- python/mxnet/callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/callback.py b/python/mxnet/callback.py index 913772a91e54..c6f466b22269 100644 --- a/python/mxnet/callback.py +++ b/python/mxnet/callback.py @@ -79,10 +79,10 @@ def __call__(self, param): if param.eval_metric is not None: name, value = param.eval_metric.get() logging.info("Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\tTrain-%s=%f", - param.epoch, count, speed, name, value) + param.epoch, count, speed, name, value) else: logging.info("Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec", - param.epoch, count, speed) + param.epoch, count, speed) self.tic = time.time() else: self.init = True From bf90f63a30556ee0ec5fd8a3ec267431a34a6745 Mon Sep 17 00:00:00 2001 From: tornadomeet Date: Wed, 23 Dec 2015 14:35:12 +0800 Subject: [PATCH 08/32] update name --- example/fcn-xs/image_segmentaion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/fcn-xs/image_segmentaion.py b/example/fcn-xs/image_segmentaion.py index 04510a933d1e..56c7482fcb81 100644 --- a/example/fcn-xs/image_segmentaion.py +++ b/example/fcn-xs/image_segmentaion.py @@ -43,12 +43,12 @@ def get_data(img_path): return img def main(): - fcn32s, fcn32s_arg_params, fcn32s_aux_params = mx.model.load_checkpoint(model_previx, epoch) - fcn32s_arg_params["data"] = mx.nd.array(get_data(img), ctx) - data_shape = fcn32s_arg_params["data"].shape + fcnxs, fcnxs_args, fcnxs_auxs = mx.model.load_checkpoint(model_previx, epoch) + fcnxs_args["data"] = mx.nd.array(get_data(img), ctx) + data_shape = fcnxs_args["data"].shape label_shape = (1, data_shape[2]*data_shape[3]) - fcn32s_arg_params["softmax_label"] = mx.nd.empty(label_shape, ctx) - exector = fcn32s.bind(ctx, fcn32s_arg_params ,args_grad=None, grad_req="null", aux_states=fcn32s_arg_params) + fcnxs_args["softmax_label"] = mx.nd.empty(label_shape, ctx) + exector = fcnxs.bind(ctx, fcnxs_args ,args_grad=None, grad_req="null", aux_states=fcnxs_args) exector.forward(is_train=False) output = exector.outputs[0] out_img = np.uint8(np.squeeze(output.asnumpy().argmax(axis=1))) From b8b9700fcdb1d7fb4c79a3d271da178639453b8a Mon Sep 17 00:00:00 2001 From: wuwei Date: Wed, 30 Dec 2015 12:33:24 +0800 Subject: [PATCH 09/32] update for pr review --- dmlc-core | 2 +- example/fcn-xs/data.py | 16 +++++++++--- example/fcn-xs/fcn_xs.py | 7 ++++++ example/fcn-xs/init_fcnxs.py | 3 +++ example/fcn-xs/run_fcnxs.sh | 12 ++++----- example/fcn-xs/symbol_fcnxs.py | 17 ++++++------- ps-lite | 2 +- src/operator/crop-inl.h | 46 ++++++++++++++++++++++++++++------ src/operator/crop.cc | 7 +++--- 9 files changed, 79 insertions(+), 33 deletions(-) diff --git a/dmlc-core b/dmlc-core index 27013a86f8b8..a9b3320d2c6b 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 27013a86f8b8fd8bb9ebf2253928436e0eb38e13 +Subproject commit a9b3320d2c6b29506139784b877142c9ee78caaf diff --git a/example/fcn-xs/data.py b/example/fcn-xs/data.py index 96eb44a44192..dcc958ea481a 100644 --- a/example/fcn-xs/data.py +++ b/example/fcn-xs/data.py @@ -4,7 +4,6 @@ import numpy as np import sys, os from mxnet.io import DataIter -from skimage import io from PIL import Image class FileIter(DataIter): @@ -62,7 +61,16 @@ def _read_img(self, img_name, label_name): if self.cut_off_size is not None: max_hw = max(img.shape[0], img.shape[1]) min_hw = min(img.shape[0], img.shape[1]) - if max_hw > cut_off_size: + if min_hw > self.cut_off_size: + rand_start_max = round(np.random.uniform(0, max_hw - self.cut_off_size - 1)) + rand_start_min = round(np.random.uniform(0, min_hw - self.cut_off_size - 1)) + if img.shape[0] == max_hw : + img = img[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size] + label = label[rand_start_max : rand_start_max + self.cut_off_size, rand_start_min : rand_start_min + self.cut_off_size] + else : + img = img[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size] + label = label[rand_start_min : rand_start_min + self.cut_off_size, rand_start_max : rand_start_max + self.cut_off_size] + elif max_hw > self.cut_off_size: rand_start = round(np.random.uniform(0, max_hw - min_hw - 1)) if img.shape[0] == max_hw : img = img[rand_start : rand_start + min_hw, :] @@ -74,9 +82,9 @@ def _read_img(self, img_name, label_name): img = img - reshaped_mean img = np.swapaxes(img, 0, 2) img = np.swapaxes(img, 1, 2) # (c, h, w) - img = np.expand_dims(img, axis=0) # (1, c, h, w) or (1, h, w) + img = np.expand_dims(img, axis=0) # (1, c, h, w) label = np.array(label) # (h, w) - label = np.expand_dims(label, axis=0) # (1, c, h, w) or (1, h, w) + label = np.expand_dims(label, axis=0) # (1, h, w) return (img, label) @property diff --git a/example/fcn-xs/fcn_xs.py b/example/fcn-xs/fcn_xs.py index 5f30ebab10aa..01344a7b123f 100644 --- a/example/fcn-xs/fcn_xs.py +++ b/example/fcn-xs/fcn_xs.py @@ -1,6 +1,11 @@ # pylint: skip-file import sys, os import argparse +# mxnet_train = "/home/work/wuwei/tools/mxnet/lib/python2.7/site-packages/mxnet-0.5.0-py2.7.egg" +mxnet_train = "/home/work/wuwei/.local/lib/python2.7/site-packages/mxnet-0.5.0-py2.7.egg" +if mxnet_train in sys.path: + sys.path.remove(mxnet_train) +sys.path.insert(0, mxnet_train) import mxnet as mx import numpy as np import logging @@ -32,11 +37,13 @@ def main(): train_dataiter = FileIter( root_dir = "./VOC2012", flist_name = "train.lst", + # cut_off_size = 400, rgb_mean = (123.68, 116.779, 103.939), ) val_dataiter = FileIter( root_dir = "./VOC2012", flist_name = "val.lst", + # cut_off_size = 400, rgb_mean = (123.68, 116.779, 103.939), ) model = Solver( diff --git a/example/fcn-xs/init_fcnxs.py b/example/fcn-xs/init_fcnxs.py index 126805e3c137..69295ce6be68 100644 --- a/example/fcn-xs/init_fcnxs.py +++ b/example/fcn-xs/init_fcnxs.py @@ -46,6 +46,9 @@ def init_from_vgg16(ctx, fcnxs_symbol, vgg16fc_args, vgg16fc_auxs): return fcnxs_args, fcnxs_auxs def init_from_fcnxs(ctx, fcnxs_symbol, fcnxs_args_from, fcnxs_auxs_from): + """ use zero initialization for better convergence, because it tends to oputut 0, + and the label 0 stands for background, which may occupy most size of one image. + """ fcnxs_args = fcnxs_args_from.copy() fcnxs_auxs = fcnxs_auxs_from.copy() for k,v in fcnxs_args.items(): diff --git a/example/fcn-xs/run_fcnxs.sh b/example/fcn-xs/run_fcnxs.sh index 8dca9b8231c5..926f3f840415 100755 --- a/example/fcn-xs/run_fcnxs.sh +++ b/example/fcn-xs/run_fcnxs.sh @@ -2,10 +2,10 @@ python -u fcn_xs.py --model=fcn32s --prefix=VGG_FC_ILSVRC_16_layers \ --epoch=74 --init-type=vgg16 -# # train fcn-16s model -# python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 \ -# --epoch=31 --init-type=fcnxs +## train fcn-16s model +#python -u fcn_xs.py --model=fcn16s --prefix=FCN32s_VGG16 \ + #--epoch=31 --init-type=fcnxs -# # train fcn-8s model -# python -u fcn_xs.py --model=fcn8s --prefix=FCN16s_VGG16 \ -# --epoch=27 --init-type=fcnxs +# train fcn-8s model +#python -u fcn_xs.py --model=fcn8s --prefix=FCN16s_VGG16 \ + #--epoch=27 --init-type=fcnxs diff --git a/example/fcn-xs/symbol_fcnxs.py b/example/fcn-xs/symbol_fcnxs.py index 0ef6bc363c8b..ab283fa13f50 100644 --- a/example/fcn-xs/symbol_fcnxs.py +++ b/example/fcn-xs/symbol_fcnxs.py @@ -139,7 +139,7 @@ def fcnxs_score(input, crop, offset, kernel=(64,64), stride=(32,32), numclass=21 # score out bigscore = mx.symbol.Deconvolution(data=input, kernel=kernel, stride=stride, num_filter=numclass, workspace=workspace_default, name="bigscore") - upscore = mx.symbol.Crop(data=bigscore, crop_like=crop, offset=offset, name="upscore") + upscore = mx.symbol.Crop(*[bigscore, crop], offset=offset, name="upscore") softmax = mx.symbol.SoftmaxOutput(data=upscore, multi_output=True, use_ignore=True, ignore_label=255, name="softmax") return softmax @@ -161,9 +161,8 @@ def get_fcn16s_symbol(numclass=21, workspace_default=1024): workspace=workspace_default, name="score2") # 2X score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass, workspace=workspace_default, name="score_pool4") - score_pool4c = mx.symbol.Crop(data=score_pool4, crop_like=score2, - offset=offset()["score_pool4c"], name="score_pool4c") - score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c], name='score_fused') + score_pool4c = mx.symbol.Crop(*[score_pool4, score2], offset=offset()["score_pool4c"], name="score_pool4c") + score_fused = score2 + score_pool4c softmax = fcnxs_score(score_fused, data, offset()["fcn16s_upscore"], (32, 32), (16, 16), numclass, workspace_default) return softmax @@ -177,16 +176,14 @@ def get_fcn8s_symbol(numclass=21, workspace_default=1024): workspace=workspace_default, name="score2") # 2X score_pool4 = mx.symbol.Convolution(data=pool4, kernel=(1, 1), num_filter=numclass, workspace=workspace_default, name="score_pool4") - score_pool4c = mx.symbol.Crop(data=score_pool4, crop_like=score2, - offset=offset()["score_pool4c"], name="score_pool4c") - score_fused = mx.symbol.ElementWiseSum(*[score2, score_pool4c], name='score_fused') + score_pool4c = mx.symbol.Crop(*[score_pool4, score2], offset=offset()["score_pool4c"], name="score_pool4c") + score_fused = score2 + score_pool4c # score 4X score4 = mx.symbol.Deconvolution(data=score_fused, kernel=(4, 4), stride=(2, 2),num_filter=numclass, workspace=workspace_default, name="score4") # 4X score_pool3 = mx.symbol.Convolution(data=pool3, kernel=(1, 1), num_filter=numclass, workspace=workspace_default, name="score_pool3") - score_pool3c = mx.symbol.Crop(data=score_pool3, crop_like=score4, - offset=offset()["score_pool3c"], name="score_pool3c") - score_final = mx.symbol.ElementWiseSum(*[score4, score_pool3c], name='score_final') + score_pool3c = mx.symbol.Crop(*[score_pool3, score4], offset=offset()["score_pool3c"], name="score_pool3c") + score_final = score4 + score_pool3c softmax = fcnxs_score(score_final, data, offset()["fcn8s_upscore"], (16, 16), (8, 8), numclass, workspace_default) return softmax diff --git a/ps-lite b/ps-lite index b1da4b6e0f9e..d175ec2393c6 160000 --- a/ps-lite +++ b/ps-lite @@ -1 +1 @@ -Subproject commit b1da4b6e0f9e387ee30d2d02a063944986ff0cbd +Subproject commit d175ec2393c6ab00d5d0a143b42ee6dc6efb7038 diff --git a/src/operator/crop-inl.h b/src/operator/crop-inl.h index 750ab833da3c..f28e93fbe330 100644 --- a/src/operator/crop-inl.h +++ b/src/operator/crop-inl.h @@ -25,12 +25,20 @@ enum CropOpOutputs {kOut}; } // namespace crop_enum struct CropParam : public dmlc::Parameter { + int num_args; TShape offset; + TShape h_w; bool center_crop; DMLC_DECLARE_PARAMETER(CropParam) { + DMLC_DECLARE_FIELD(num_args).set_range(1, 3) + .describe("Number of inputs for crop, if equals one, then we will use the h_w" + "for crop heihgt and width, else if equals two, then we will use the height" + "and width of the second input symbol, we name crop_like here"); int shape[] = {0, 0}; DMLC_DECLARE_FIELD(offset).set_default(TShape(shape, shape + 2)) .describe("corp offset coordinate: (y, x)"); + DMLC_DECLARE_FIELD(h_w).set_default(TShape(shape, shape + 2)) + .describe("corp height and weight: (h, w)"); DMLC_DECLARE_FIELD(center_crop).set_default(false) .describe("If set to true, then it will use be the center_crop," "or it will crop using the shape of crop_like"); @@ -130,25 +138,47 @@ class CropProp : public OperatorProperty { } std::vector ListArguments() const override { - return {"data", "crop_like"}; + // return {"data", "crop_like"}; + std::vector ret; + for (int i = 0; i < param_.num_args; ++i) { + ret.push_back(std::string("arg") + static_cast('0' + i)); + } + return ret; } bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { using namespace mshadow; - CHECK_EQ(in_shape->size(), 2) << "Input:[data, crop_like]"; + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); TShape data_shape = in_shape->at(crop_enum::kData); if (data_shape.ndim() == 0) return false; CHECK_EQ(data_shape.ndim(), 4) << \ "Input data should be 4D in batch-num_filter-y-x"; - TShape crop_shape = in_shape->at(crop_enum::kCropLike); - if (crop_shape.ndim() == 0) return false; - CHECK_EQ(crop_shape.ndim(), 4) << \ - "Input crop_like should be 4D in batch-num_filter/batch-num_channel-y-x"; + std::vector crop_shape; + if (param_.num_args == 1) { + std::cout << "ok1" << std::endl; + CHECK_GE(static_cast(param_.h_w[0]), 1) << + "the crop height(h_w[0]) should be larger than 1"; + CHECK_LE(static_cast(param_.h_w[0]), static_cast(data_shape[2])) << + "the crop height(h_w[0]) should be less than the input data's height"; + CHECK_GE(static_cast(param_.h_w[1]), 1) << + "the crop width(h_w[1]) should be larger than 1"; + CHECK_LE(static_cast(param_.h_w[1]), static_cast(data_shape[3])) << + "the crop width(h_w[1]) should be less than the input data's width"; + crop_shape.push_back(param_.h_w[0]); + crop_shape.push_back(param_.h_w[1]); + } else if (param_.num_args == 2) { + TShape crop_like_shape = in_shape->at(crop_enum::kCropLike); + crop_shape.push_back(crop_like_shape[2]); + crop_shape.push_back(crop_like_shape[3]); + } + if (crop_shape.size() == 0) return false; + CHECK_EQ(crop_shape.size(), 2) << \ + "Input crop_like should be 2D in height-width"; out_shape->clear(); - data_shape[2] = crop_shape[2]; - data_shape[3] = crop_shape[3]; + data_shape[2] = crop_shape[0]; + data_shape[3] = crop_shape[1]; out_shape->push_back(data_shape); return true; } diff --git a/src/operator/crop.cc b/src/operator/crop.cc index 5a3315c24d63..681d192237ee 100644 --- a/src/operator/crop.cc +++ b/src/operator/crop.cc @@ -21,9 +21,10 @@ Operator* CropProp::CreateOperator(Context ctx) const { DMLC_REGISTER_PARAMETER(CropParam); MXNET_REGISTER_OP_PROPERTY(Crop, CropProp) -.add_argument("data", "Symbol", "Input data to the CropOp.") -.add_argument("crop_like", "Symbol", "crop_like data to the CropOp.") +.describe("Crop the 2th and 3th dim of input data, with the corresponding size of crop_like.") +// .add_argument("data", "Symbol", "Input data to the CropOp.") +// .add_argument("crop_like", "Symbol", "crop_like data to the CropOp.") .add_arguments(CropParam::__FIELDS__()) -.describe("Crop the 2th and 3th dim of input data, with the corresponding size of crop_like."); +.set_key_var_num_args("num_args"); } // namespace op } // namespace mxnet From d91b5b7567b81d513e74d6e0fdd310ffe701cd9d Mon Sep 17 00:00:00 2001 From: winsty Date: Wed, 30 Dec 2015 13:02:13 +0800 Subject: [PATCH 10/32] fix HSV --- src/io/image_augmenter.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index 51586887ede2..5c7c7c2359d5 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -83,11 +83,11 @@ struct ImageAugmentParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(min_img_size).set_default(0.0f) .describe("Augmentation Param: Minimum image size after resizing."); DMLC_DECLARE_FIELD(random_h).set_default(0) - .describe("Augmentation Param: Maximum value of H channel in HSV color space."); + .describe("Augmentation Param: Maximum value of H channel in HSL color space."); DMLC_DECLARE_FIELD(random_s).set_default(0) - .describe("Augmentation Param: Maximum value of S channel in HSV color space."); + .describe("Augmentation Param: Maximum value of S channel in HSL color space."); DMLC_DECLARE_FIELD(random_l).set_default(0) - .describe("Augmentation Param: Maximum value of L channel in HSV color space."); + .describe("Augmentation Param: Maximum value of L channel in HSL color space."); DMLC_DECLARE_FIELD(rotate).set_default(-1.0f) .describe("Augmentation Param: Rotate angle."); DMLC_DECLARE_FIELD(fill_value).set_default(255) From 5830796eaae8e0b43aff352f59bd995b6cc6ed57 Mon Sep 17 00:00:00 2001 From: wuwei Date: Wed, 30 Dec 2015 12:40:56 +0800 Subject: [PATCH 11/32] update pr review2 fix description fix baiduyun address remove debug info --- example/fcn-xs/README.md | 6 +++--- src/operator/crop-inl.h | 1 - src/operator/crop.cc | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/example/fcn-xs/README.md b/example/fcn-xs/README.md index e970eb4a2414..a902fcdee7ac 100644 --- a/example/fcn-xs/README.md +++ b/example/fcn-xs/README.md @@ -17,10 +17,10 @@ the training image number is only : 2027, and the Validation image number is: 46 ## How to train fcn-xs in mxnet #### step1: download the vgg16fc model and experiment data -* vgg16fc model : you can download the ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1gerce1H). +* vgg16fc model : you can download the ```VGG_FC_ILSVRC_16_layers-symbol.json``` and ```VGG_FC_ILSVRC_16_layers-0074.params``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC). this is the fully convolution style of the origin [VGG_ILSVRC_16_layers.caffemodel](http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel), and the corresponding [VGG_ILSVRC_16_layers_deploy.prototxt](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-vgg_ilsvrc_16_layers_deploy-prototxt), the vgg16 model has [license](http://creativecommons.org/licenses/by-nc/4.0/) for non-commercial use only. -* experiment data : you can download the ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1jGlOvno), and Extract it. the file/folder will be like: +* experiment data : you can download the ```VOC2012.rar``` from [yun.baidu](http://pan.baidu.com/s/1bgz4PC), and Extract it. the file/folder will be like: ```JPEGImages folder```, ```SegmentationClass folder```, ```train.lst```, ```val.lst```, ```test.lst``` #### step2: train fcn-xs model @@ -47,7 +47,7 @@ INFO:root:Epoch[0] Batch [350] Speed: 1.12 samples/sec Train-accuracy=0.912080 ``` ## Using the pre-trained model for image segmentation -* similarly, you should firstly download the pre-trained model from [yun.baidu](http://pan.baidu.com/s/1gerce1H), the symbol and model file is ```FCN8s_VGG16-symbol.json```, ```FCN8s_VGG16-0019.params``` +* similarly, you should firstly download the pre-trained model from [yun.baidu](http://pan.baidu.com/s/1bgz4PC), the symbol and model file is ```FCN8s_VGG16-symbol.json```, ```FCN8s_VGG16-0019.params``` * then put the image in your directory for segmentation, and change the ```img = YOUR_IMAGE_NAME``` in ```image_segmentaion.py``` * lastly, use ```image_segmentaion.py``` to segmentation one image by run in shell ```python image_segmentaion.py```, then you will get the segmentation image like the sample result above. diff --git a/src/operator/crop-inl.h b/src/operator/crop-inl.h index f28e93fbe330..98a081fed5b3 100644 --- a/src/operator/crop-inl.h +++ b/src/operator/crop-inl.h @@ -157,7 +157,6 @@ class CropProp : public OperatorProperty { "Input data should be 4D in batch-num_filter-y-x"; std::vector crop_shape; if (param_.num_args == 1) { - std::cout << "ok1" << std::endl; CHECK_GE(static_cast(param_.h_w[0]), 1) << "the crop height(h_w[0]) should be larger than 1"; CHECK_LE(static_cast(param_.h_w[0]), static_cast(data_shape[2])) << diff --git a/src/operator/crop.cc b/src/operator/crop.cc index 681d192237ee..2d46a64df78e 100644 --- a/src/operator/crop.cc +++ b/src/operator/crop.cc @@ -21,9 +21,8 @@ Operator* CropProp::CreateOperator(Context ctx) const { DMLC_REGISTER_PARAMETER(CropParam); MXNET_REGISTER_OP_PROPERTY(Crop, CropProp) -.describe("Crop the 2th and 3th dim of input data, with the corresponding size of crop_like.") -// .add_argument("data", "Symbol", "Input data to the CropOp.") -// .add_argument("crop_like", "Symbol", "crop_like data to the CropOp.") +.describe("Crop the 2th and 3th dim of input data, with the corresponding size of w_h or" +"with widht and height of the second input symbol") .add_arguments(CropParam::__FIELDS__()) .set_key_var_num_args("num_args"); } // namespace op From aa82d39f961a35eb8230ccd59653024e450f9d56 Mon Sep 17 00:00:00 2001 From: Rodrigo Castro Date: Wed, 30 Dec 2015 21:18:35 +0000 Subject: [PATCH 12/32] Fix issue #1116 - the learning rate is too big, leading to NRR=Infinity and Prep=Infinity - the optimizer was changed from 'sgd' to 'rmsprop', but the hyperparameters were not updated: there is no 'momentum' hyperparameter in rmsprop. - the 'input.txt' file is not automatically downloaded when running the ./get_ptb_data.sh script - the script get_ptb_data.sh script should use [#!/usr/bin/env](http://unix.stackexchange.com/questions/29608/why-is-it-better-to-use-usr-bin-env-name-instead-of-path-to-name-as-my) --- example/rnn/char_lstm.ipynb | 10 +++++----- example/rnn/get_ptb_data.sh | 5 +++-- 2 files changed, 8 insertions(+), 7 deletions(-) mode change 100644 => 100755 example/rnn/get_ptb_data.sh diff --git a/example/rnn/char_lstm.ipynb b/example/rnn/char_lstm.ipynb index 72ba3f18dc41..de152e4c9865 100644 --- a/example/rnn/char_lstm.ipynb +++ b/example/rnn/char_lstm.ipynb @@ -8,7 +8,9 @@ "This example aims to show how to use lstm to build a char level language model, and generate text from it. \n", "We use a tiny shakespeare text for demo purpose. \n", "\n", - "Data can be found at [https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare](https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare). " + "Data can be found at [https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare](https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare). ", + "\n", + "If running for the first time, download the data by running the following commands: cd example/rnn ; ./get_ptb_data.sh" ] }, { @@ -48,9 +50,8 @@ "num_embed = 256\n", "num_lstm_layer = 2\n", "num_round = 21\n", - "learning_rate= 1\n", + "learning_rate= 0.01\n", "wd=0.00001\n", - "momentum=0.0\n", "clip_gradient=1\n", "update_period = 1\n" ] @@ -138,7 +139,7 @@ } ], "source": [ - "X, dic, lookup_table = make_batch(\"./input.txt\", batch_size=batch_size, seq_lenth=seq_len)\n", + "X, dic, lookup_table = make_batch(\"./data/input.txt\", batch_size=batch_size, seq_lenth=seq_len)\n", "vocab = len(dic)" ] }, @@ -443,7 +444,6 @@ " update_period=update_period,\n", " learning_rate=learning_rate,\n", " wd=wd,\n", - " momentum=momentum,\n", " clip_gradient=clip_gradient)" ] }, diff --git a/example/rnn/get_ptb_data.sh b/example/rnn/get_ptb_data.sh old mode 100644 new mode 100755 index 2b517f4ebc4d..1ec009aa2f99 --- a/example/rnn/get_ptb_data.sh +++ b/example/rnn/get_ptb_data.sh @@ -1,4 +1,4 @@ -#!/bin/env bash +#!/usr/bin/env bash RNN_DIR=$(cd `dirname $0`; pwd) DATA_DIR="${RNN_DIR}/data/" @@ -8,6 +8,7 @@ if [[ ! -d "${DATA_DIR}" ]]; then mkdir -p ${DATA_DIR} fi -wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt; +wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt; wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt; wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt; +wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt; From bc48a4318835b7264308c2876b2686ced83b83ca Mon Sep 17 00:00:00 2001 From: wuwei Date: Thu, 31 Dec 2015 11:18:03 +0800 Subject: [PATCH 13/32] remove personal path --- example/fcn-xs/fcn_xs.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/example/fcn-xs/fcn_xs.py b/example/fcn-xs/fcn_xs.py index 01344a7b123f..85961d92c694 100644 --- a/example/fcn-xs/fcn_xs.py +++ b/example/fcn-xs/fcn_xs.py @@ -1,11 +1,6 @@ # pylint: skip-file import sys, os import argparse -# mxnet_train = "/home/work/wuwei/tools/mxnet/lib/python2.7/site-packages/mxnet-0.5.0-py2.7.egg" -mxnet_train = "/home/work/wuwei/.local/lib/python2.7/site-packages/mxnet-0.5.0-py2.7.egg" -if mxnet_train in sys.path: - sys.path.remove(mxnet_train) -sys.path.insert(0, mxnet_train) import mxnet as mx import numpy as np import logging @@ -43,7 +38,6 @@ def main(): val_dataiter = FileIter( root_dir = "./VOC2012", flist_name = "val.lst", - # cut_off_size = 400, rgb_mean = (123.68, 116.779, 103.939), ) model = Solver( From 2de3d2887706a2c85f4d40383deb9a6e0a94822a Mon Sep 17 00:00:00 2001 From: shijianping Date: Thu, 31 Dec 2015 22:22:37 +0800 Subject: [PATCH 14/32] add dilate for convolution --- mshadow | 2 +- src/operator/convolution-inl.h | 31 +++++++++++++++++++++++-------- src/operator/deconvolution-inl.h | 27 +++++++++++++++++++++------ 3 files changed, 45 insertions(+), 15 deletions(-) diff --git a/mshadow b/mshadow index 47521c6f8e0d..8b6d2c8b148e 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 47521c6f8e0d62a0224bc5bb19b60cc6a0d6a95c +Subproject commit 8b6d2c8b148ed47a310a5a5ab5798bdd058f4f35 diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index 0a2313592e04..4234226bbaf4 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -30,6 +30,7 @@ enum ConvolutionOpResource {kTempSpace}; struct ConvolutionParam : public dmlc::Parameter { TShape kernel; TShape stride; + TShape dilate; TShape pad; uint32_t num_filter; uint32_t num_group; @@ -40,6 +41,8 @@ struct ConvolutionParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(kernel).describe("convolution kernel size: (y, x)"); DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) .describe("convolution stride: (y, x)"); + DMLC_DECLARE_FIELD(dilate).set_default(TShape(shape, shape + 2)) + .describe("convolution dilate: (y, x)"); shape[0] = shape[1] = 0; DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) .describe("pad for convolution: (y, x)"); @@ -105,14 +108,18 @@ class ConvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } else { temp_col = unpack_patch2col(pad(data.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -181,13 +188,17 @@ class ConvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } else { temp_col = unpack_patch2col(pad(data.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -209,7 +220,8 @@ class ConvolutionOp : public Operator { data.Slice(i, i + step).shape_, param_.kernel[0], param_.kernel[1], - param_.stride[0]); + param_.stride[0], + param_.dilate[0]); } else { Shape<4> pshape = data.Slice(i, i + step).shape_; pshape[2] += 2 * param_.pad[0]; @@ -218,7 +230,8 @@ class ConvolutionOp : public Operator { pshape, param_.kernel[0], param_.kernel[1], - param_.stride[0]), + param_.stride[0], + param_.dilate[0]), gdata[i][0].shape_); } } @@ -318,11 +331,13 @@ class ConvolutionProp : public OperatorProperty { << "incorrect kernel size: " << param_.kernel; CHECK_GE(param_.stride.Size(), 0) \ << "incorrect stride size: " << param_.stride; + CHECK_GE(param_.dilate.Size(), 0) \ + << "incorrect dilate size: " << param_.dilate; CHECK(ksize_x <= dshape[3] && ksize_y <= dshape[2]) << "kernel size exceed input"; (*out_shape)[conv::kOut][1] = param_.num_filter; - (*out_shape)[conv::kOut][2] = (dshape[2] + 2 * param_.pad[0] - ksize_y) / param_.stride[0] + 1; - (*out_shape)[conv::kOut][3] = (dshape[3] + 2 * param_.pad[1] - ksize_x) / param_.stride[1] + 1; + (*out_shape)[conv::kOut][2] = (dshape[2] + 2 * param_.pad[0] - (param_.dilate[0] == 1 ? ksize_y : ksize_y * param_.dilate[0] - 1)) / param_.stride[0] + 1; + (*out_shape)[conv::kOut][3] = (dshape[3] + 2 * param_.pad[1] - (param_.dilate[1] == 1 ? ksize_x : ksize_x * param_.dilate[1] - 1)) / param_.stride[1] + 1; return true; } diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index 9733dcb02dc1..0fa0e6679c06 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -30,6 +30,7 @@ namespace deconv { struct DeconvolutionParam : public dmlc::Parameter { TShape kernel; TShape stride; + TShape dilate; TShape pad; uint32_t num_filter; uint32_t num_group; @@ -40,6 +41,8 @@ struct DeconvolutionParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(kernel).describe("deconvolution kernel size: (y, x)"); DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) .describe("deconvolution stride: (y, x)"); + DMLC_DECLARE_FIELD(dilate).set_default(TShape(shape, shape + 2)) + .describe("deconvolution dilate: (y, x)"); shape[0] = shape[1] = 0; DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) .describe("pad for deconvolution: (y, x)"); @@ -104,14 +107,18 @@ class DeconvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } else { temp_col = unpack_patch2col(pad(out.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -124,7 +131,8 @@ class DeconvolutionOp : public Operator { out.Slice(i, i + step).shape_, param_.kernel[0], param_.kernel[1], - param_.stride[0]); + param_.stride[0], + param_.dilate[0]); } else { Shape<4> pshape = out.Slice(i, i + step).shape_; pshape[2] += 2 * param_.pad[0]; @@ -133,7 +141,8 @@ class DeconvolutionOp : public Operator { pshape, param_.kernel[0], param_.kernel[1], - param_.stride[0]), + param_.stride[0], + param_.dilate[0]), out[i][0].shape_); } } @@ -192,13 +201,17 @@ class DeconvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } else { temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.stride[1]); + param_.stride[1], + param_.dilate[0], + param_.dilate[1]); } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -316,6 +329,8 @@ class DeconvolutionProp : public OperatorProperty { << "incorrect kernel size: " << param_.kernel; CHECK_GE(param_.stride.Size(), 0) \ << "incorrect stride size: " << param_.stride; + CHECK_EQ(param_.dilate.Size(), 1) \ + << "Dilate not supported in deconvolution, incorrect stride size: " << param_.stride; (*out_shape)[deconv::kOut][1] = param_.num_filter; (*out_shape)[deconv::kOut][2] = param_.stride[0] * (dshape[2] - 1) + ksize_y - 2 * param_.pad[0]; From afdfac3225815ea5d14914fd953092f0ac058986 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 28 Dec 2015 23:23:31 +0800 Subject: [PATCH 15/32] make Executor::SetMonitorCallback take std::function as parameter --- include/mxnet/c_api.h | 6 +++--- include/mxnet/symbolic.h | 6 +++++- src/c_api/c_api.cc | 8 ++++++-- src/symbol/graph_executor.h | 3 ++- 4 files changed, 16 insertions(+), 7 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index f070ac8a8f80..d1c54d256f31 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -55,8 +55,8 @@ typedef void *OptimizerCreator; /*! \brief handle to Optimizer*/ typedef void *OptimizerHandle; -MXNET_EXTERN_C typedef void (*ExcecutorMonitorCallback)(const char*, - NDArrayHandle); +MXNET_EXTERN_C typedef void (*ExecutorMonitorCallback)(const char*, + NDArrayHandle); MXNET_EXTERN_C { struct NativeOpInfo { @@ -730,7 +730,7 @@ MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle, * \brief set a call back to notify the completion of operation */ MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, - ExcecutorMonitorCallback callback); + ExecutorMonitorCallback callback); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index 3876c211cfcc..e61e886aab32 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -320,10 +320,14 @@ class Executor { const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states); + /*! + * \brief the prototype of user-defined monitor callback + */ + typedef std::function MonitorCallback; /*! * \brief Install a callback to notify the completion of operation. */ - virtual void SetMonitorCallback(ExcecutorMonitorCallback callback) {} + virtual void SetMonitorCallback(const MonitorCallback& callback) {} }; // class operator } // namespace mxnet #endif // MXNET_SYMBOLIC_H_ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3deea52f9e9d..864c2524cf0e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -824,10 +824,14 @@ int MXExecutorBindX(SymbolHandle symbol_handle, } int MXExecutorSetMonitorCallback(ExecutorHandle handle, - ExcecutorMonitorCallback callback) { + ExecutorMonitorCallback callback) { API_BEGIN(); + std::function clbk + = [callback](const char *name, void* handle) { + callback(name, handle); + }; Executor *exec = static_cast(handle); - exec->SetMonitorCallback(callback); + exec->SetMonitorCallback(clbk); API_END(); } diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index 1ec47294a124..ba218e330c3f 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -32,7 +32,8 @@ class GraphExecutor : public Executor { } void Print(std::ostream &os) const override; // NOLINT(*) // install callback - void SetMonitorCallback(ExcecutorMonitorCallback callback) { + void SetMonitorCallback(const MonitorCallback& callback) { + CHECK(callback) << "invalid callback"; monitor_callback_ = callback; } // implement Executor::Bind, only call it once. From bd84b2c71a9c469da37b87787db36a9d7998c768 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Mon, 28 Dec 2015 23:26:57 +0800 Subject: [PATCH 16/32] example/autoencoder fix optimizer param in layerwise_pretrain & finetune --- example/autoencoder/autoencoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/example/autoencoder/autoencoder.py b/example/autoencoder/autoencoder.py index 9d3fd253c947..8fcd663c2b33 100644 --- a/example/autoencoder/autoencoder.py +++ b/example/autoencoder/autoencoder.py @@ -134,7 +134,7 @@ def make_decoder(self, feature, dims, sparseness_penalty=None, dropout=None, int def layerwise_pretrain(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None): def l2_norm(label, pred): return np.mean(np.square(label-pred))/2.0 - solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler) + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler) solver.set_metric(mx.metric.CustomMetric(l2_norm)) solver.set_monitor(Monitor(1000)) data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, @@ -154,7 +154,7 @@ def l2_norm(label, pred): def finetune(self, X, batch_size, n_iter, optimizer, l_rate, decay, lr_scheduler=None): def l2_norm(label, pred): return np.mean(np.square(label-pred))/2.0 - solver = Solver('sgd', momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler) + solver = Solver(optimizer, momentum=0.9, wd=decay, learning_rate=l_rate, lr_scheduler=lr_scheduler) solver.set_metric(mx.metric.CustomMetric(l2_norm)) solver.set_monitor(Monitor(1000)) data_iter = mx.io.NDArrayIter({'data': X}, batch_size=batch_size, shuffle=True, From def03c2f64d7963a174eef09fd56b33590816954 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 30 Dec 2015 00:45:52 +0800 Subject: [PATCH 17/32] make MXExecutorSetMonitorCallback take an additional pointer as argument, allow frontend packages bind their closure objects fix python lint make MXExecutorSetMonitorCallback take an additional pointer as argument, allow frontend packages bind their closure objects --- include/mxnet/c_api.h | 6 ++++-- python/mxnet/executor.py | 14 +++++++++++--- src/c_api/c_api.cc | 9 ++++++--- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d1c54d256f31..6087fde01ea5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -56,7 +56,8 @@ typedef void *OptimizerCreator; typedef void *OptimizerHandle; MXNET_EXTERN_C typedef void (*ExecutorMonitorCallback)(const char*, - NDArrayHandle); + NDArrayHandle, + void *); MXNET_EXTERN_C { struct NativeOpInfo { @@ -730,7 +731,8 @@ MXNET_DLL int MXExecutorBindX(SymbolHandle symbol_handle, * \brief set a call back to notify the completion of operation */ MXNET_DLL int MXExecutorSetMonitorCallback(ExecutorHandle handle, - ExecutorMonitorCallback callback); + ExecutorMonitorCallback callback, + void* callback_handle); //-------------------------------------------- // Part 5: IO Interface //-------------------------------------------- diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 3e357fb18974..4b44272cab7f 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -12,6 +12,13 @@ from .context import cpu import logging +def _monitor_callback_wrapper(callback): + """ a wrapper for the user-defined handle """ + def callback_handle(name, array, _): + """ ctypes function """ + callback(name, array) + return callback_handle + class Executor(object): """ Executor is the actual executing object of MXNet.""" def __init__(self, handle, symbol): @@ -129,11 +136,12 @@ def set_monitor_callback(self, callback): callback : function Takes a string and an NDArrayHandle. """ - cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle) - self._monitor_callback = cb_type(callback) + cb_type = ctypes.CFUNCTYPE(None, ctypes.c_char_p, NDArrayHandle, ctypes.c_void_p) + self._monitor_callback = cb_type(_monitor_callback_wrapper(callback)) check_call(_LIB.MXExecutorSetMonitorCallback( self.handle, - self._monitor_callback)) + self._monitor_callback, + None)) @property def arg_dict(self): diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 864c2524cf0e..b8a03b1c276f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -824,11 +824,14 @@ int MXExecutorBindX(SymbolHandle symbol_handle, } int MXExecutorSetMonitorCallback(ExecutorHandle handle, - ExecutorMonitorCallback callback) { + ExecutorMonitorCallback callback, + void* callback_handle) { API_BEGIN(); + ExecutorMonitorCallback callback_temp = callback; + void* callback_handle_temp = callback_handle; std::function clbk - = [callback](const char *name, void* handle) { - callback(name, handle); + = [callback_temp, callback_handle_temp](const char *name, void* handle) { + callback_temp(name, handle, callback_handle_temp); }; Executor *exec = static_cast(handle); exec->SetMonitorCallback(clbk); From 0b94a6a598d238e94127f0c0fa955db2c9042383 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 30 Dec 2015 02:44:14 -0700 Subject: [PATCH 18/32] fix wd in sgd --- dmlc-core | 2 +- ps-lite | 2 +- python/mxnet/model.py | 3 ++- python/mxnet/optimizer.py | 18 +++++++++++++++--- 4 files changed, 19 insertions(+), 6 deletions(-) diff --git a/dmlc-core b/dmlc-core index a9b3320d2c6b..98879773f062 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit a9b3320d2c6b29506139784b877142c9ee78caaf +Subproject commit 98879773f062e1c5b8a380e002c25672f2b48b13 diff --git a/ps-lite b/ps-lite index d175ec2393c6..1b0d1df2d297 160000 --- a/ps-lite +++ b/ps-lite @@ -1 +1 @@ -Subproject commit d175ec2393c6ab00d5d0a143b42ee6dc6efb7038 +Subproject commit 1b0d1df2d2971414e8e2bafd2bfe65ed6965baaf diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 3200c7b46233..39a2c584be96 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -621,6 +621,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', arg_names, param_names, aux_names = \ self._init_params(dict(data.provide_data+data.provide_label)) + self.kwargs["arg_names"] = arg_names # setup metric if not isinstance(eval_metric, metric.EvalMetric): @@ -708,7 +709,7 @@ def load(prefix, epoch, ctx=None, **kwargs): @staticmethod def create(symbol, X, y=None, ctx=None, - num_epoch=None, epoch_size=None, optimizer='ccsgd', initializer=Uniform(0.01), + num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01), eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None, work_load_list=None, **kwargs): diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 738e39752edd..b321424954ff 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -145,19 +145,28 @@ class SGD(Optimizer): clip_gradient : float, optional clip gradient in range [-clip_gradient, clip_gradient] + + arg_names : list(str), optional + special treat weight decay in parameter ends with bias, gamma, and beta """ def __init__(self, learning_rate=0.01, momentum=0.0, wd=0.0001, rescale_grad=1, clip_gradient=None, - lr_scheduler=None): + lr_scheduler=None, arg_names=None): super(SGD, self).__init__(rescale_grad) self.lr = learning_rate self.momentum = momentum self.wd = wd self.clip_gradient = clip_gradient self.lr_scheduler = lr_scheduler + self.weight_set = set([]) if lr_scheduler is not None: self.lr_scheduler.base_lr = learning_rate + if arg_names is not None: + for idx, name in enumerate(arg_names): + if name.endswith("weight"): + self.weight_set.add(idx) + def create_state(self, index, weight): """Create additional optimizer state such as momentum. @@ -189,7 +198,6 @@ def update(self, index, weight, grad, state): state : NDArray or other objects returned by init_state The auxiliary state used in optimization. """ - # TODO(bing) implement wd_bias, wd_gamma, wd_beta assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) if self.lr_scheduler is not None: @@ -199,6 +207,10 @@ def update(self, index, weight, grad, state): lr = self.lr lr *= self.lr_scale.get(index, 1.0) + wd = 0. + if index in self.weight_set: + wd = self.wd + grad = grad * self.rescale_grad if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) @@ -206,7 +218,7 @@ def update(self, index, weight, grad, state): if state: mom = state mom[:] *= self.momentum - mom[:] += -lr * (grad + self.wd * weight) + mom[:] += -lr * (grad + wd * weight) weight[:] += mom else: assert self.momentum == 0.0 From c54d52ba0123c94f83b743b58a717f4e8f251ee9 Mon Sep 17 00:00:00 2001 From: shijianping Date: Sun, 3 Jan 2016 17:19:23 +0800 Subject: [PATCH 19/32] fix lint error update mshadow fix lint --- mshadow | 2 +- src/operator/convolution-inl.h | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mshadow b/mshadow index 8b6d2c8b148e..26e3aa1f9a45 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 8b6d2c8b148ed47a310a5a5ab5798bdd058f4f35 +Subproject commit 26e3aa1f9a4519230526668f653bf67001131942 diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index 4234226bbaf4..aef223ee4b35 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -336,8 +336,10 @@ class ConvolutionProp : public OperatorProperty { CHECK(ksize_x <= dshape[3] && ksize_y <= dshape[2]) << "kernel size exceed input"; (*out_shape)[conv::kOut][1] = param_.num_filter; - (*out_shape)[conv::kOut][2] = (dshape[2] + 2 * param_.pad[0] - (param_.dilate[0] == 1 ? ksize_y : ksize_y * param_.dilate[0] - 1)) / param_.stride[0] + 1; - (*out_shape)[conv::kOut][3] = (dshape[3] + 2 * param_.pad[1] - (param_.dilate[1] == 1 ? ksize_x : ksize_x * param_.dilate[1] - 1)) / param_.stride[1] + 1; + (*out_shape)[conv::kOut][2] = (dshape[2] + 2 * param_.pad[0] - + (param_.dilate[0] == 1 ? ksize_y : ksize_y * param_.dilate[0] - 1)) / param_.stride[0] + 1; + (*out_shape)[conv::kOut][3] = (dshape[3] + 2 * param_.pad[1] - + (param_.dilate[1] == 1 ? ksize_x : ksize_x * param_.dilate[1] - 1)) / param_.stride[1] + 1; return true; } From 725d5a8f460374e0a0e21dfd3628998dc31a18dd Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Fri, 1 Jan 2016 16:51:18 -0700 Subject: [PATCH 20/32] fix wd --- dmlc-core | 2 +- mshadow | 2 +- ps-lite | 2 +- python/mxnet/optimizer.py | 42 +++++++++++++++++++++++++-------------- 4 files changed, 30 insertions(+), 18 deletions(-) diff --git a/dmlc-core b/dmlc-core index 98879773f062..ec454218564f 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 98879773f062e1c5b8a380e002c25672f2b48b13 +Subproject commit ec454218564fee8e531aee02b8943a4634330ce1 diff --git a/mshadow b/mshadow index 47521c6f8e0d..01ce2c5d5214 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 47521c6f8e0d62a0224bc5bb19b60cc6a0d6a95c +Subproject commit 01ce2c5d5214847b59ef4980e29c08179ab1d518 diff --git a/ps-lite b/ps-lite index 1b0d1df2d297..e86dac79d4ae 160000 --- a/ps-lite +++ b/ps-lite @@ -1 +1 @@ -Subproject commit 1b0d1df2d2971414e8e2bafd2bfe65ed6965baaf +Subproject commit e86dac79d4ae93274af8abf6c20c91dc118dc9a2 diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index b321424954ff..86fddc0ab88e 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -86,11 +86,23 @@ def _init_cc_optimizer(name, param_keys, param_vals): ctypes.byref(handle))) return handle - def __init__(self, rescale_grad=1): + def __init__(self, rescale_grad=1, arg_names=None): self.rescale_grad = rescale_grad self.lr_scale = {} self.num_update = 0 self._index_update_count = {} + self.specialized = False + self.weight_set = set([]) + if arg_names is not None: + self.specialized = True + index = 0 + for name in arg_names: + if name.endswith('data') or name.endswith('label'): + continue + elif name.endswith("weight"): + self.weight_set.add(index) + index += 1 + def create_state(self, index, weight): """Create additional optimizer state such as momentum. @@ -152,21 +164,15 @@ class SGD(Optimizer): def __init__(self, learning_rate=0.01, momentum=0.0, wd=0.0001, rescale_grad=1, clip_gradient=None, lr_scheduler=None, arg_names=None): - super(SGD, self).__init__(rescale_grad) + super(SGD, self).__init__(rescale_grad, arg_names) self.lr = learning_rate self.momentum = momentum self.wd = wd self.clip_gradient = clip_gradient self.lr_scheduler = lr_scheduler - self.weight_set = set([]) if lr_scheduler is not None: self.lr_scheduler.base_lr = learning_rate - if arg_names is not None: - for idx, name in enumerate(arg_names): - if name.endswith("weight"): - self.weight_set.add(idx) - def create_state(self, index, weight): """Create additional optimizer state such as momentum. @@ -207,9 +213,11 @@ def update(self, index, weight, grad, state): lr = self.lr lr *= self.lr_scale.get(index, 1.0) - wd = 0. - if index in self.weight_set: - wd = self.wd + wd = self.wd + if self.specialized == True: + wd = 0. + if index in self.weight_set: + wd = self.wd grad = grad * self.rescale_grad if self.clip_gradient is not None: @@ -283,7 +291,6 @@ def update(self, index, weight, grad, state): state : NDArray or other objects returned by init_state The auxiliary state used in optimization. """ - # TODO(bing) implement wd_bias, wd_gamma, wd_beta assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) if self.lr_scheduler is not None: @@ -450,8 +457,8 @@ class RMSProp(Optimizer): def __init__(self, learning_rate=0.002, gamma1=0.95, gamma2=0.9, wd=0., rescale_grad=1, clip_gradient=None, - lr_scheduler=None): - super(RMSProp, self).__init__(rescale_grad) + lr_scheduler=None, arg_names=None): + super(RMSProp, self).__init__(rescale_grad, arg_names) self.lr = learning_rate self.gamma1 = gamma1 self.gamma2 = gamma2 @@ -490,12 +497,17 @@ def update(self, index, weight, grad, state): lr = self.lr lr *= self.lr_scale.get(index, 1.0) n, g, delta = state + wd = self.wd + if self.specialized == True: + wd = 0. + if index in self.weight_set: + wd = self.wd grad = grad * self.rescale_grad if self.clip_gradient is not None: grad = clip(grad, -self.clip_gradient, self.clip_gradient) n[:] = (1 - self.gamma1) * (grad * grad) + self.gamma1 * n g[:] = (1 - self.gamma1) * grad + self.gamma1 * g - delta[:] = (self.gamma2) * delta - lr * (grad/sqrt(n - g*g + 1e-4) + self.wd * weight) + delta[:] = (self.gamma2) * delta - lr * (grad/sqrt(n - g*g + 1e-4) + wd * weight) weight[:] += delta @register class Test(Optimizer): From 95f6ed9344b0a190cc1908c16a3cf838437a6906 Mon Sep 17 00:00:00 2001 From: shijianping Date: Sun, 3 Jan 2016 21:52:40 +0800 Subject: [PATCH 21/32] fix deconvolution dilate default setting fix lint error --- src/operator/deconvolution-inl.h | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index 0fa0e6679c06..bd2cc41a7f4d 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -30,7 +30,6 @@ namespace deconv { struct DeconvolutionParam : public dmlc::Parameter { TShape kernel; TShape stride; - TShape dilate; TShape pad; uint32_t num_filter; uint32_t num_group; @@ -41,8 +40,6 @@ struct DeconvolutionParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(kernel).describe("deconvolution kernel size: (y, x)"); DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) .describe("deconvolution stride: (y, x)"); - DMLC_DECLARE_FIELD(dilate).set_default(TShape(shape, shape + 2)) - .describe("deconvolution dilate: (y, x)"); shape[0] = shape[1] = 0; DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) .describe("pad for deconvolution: (y, x)"); @@ -108,8 +105,7 @@ class DeconvolutionOp : public Operator { param_.kernel[1], param_.stride[0], param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + 1, 1); // Deconvolution only support dilate equals 1 } else { temp_col = unpack_patch2col(pad(out.Slice(i, i + step), param_.pad[0], param_.pad[1]), @@ -117,8 +113,7 @@ class DeconvolutionOp : public Operator { param_.kernel[1], param_.stride[0], param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + 1, 1); // Deconvolution only support dilate equals 1 } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -132,7 +127,7 @@ class DeconvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.dilate[0]); + 1); // Deconvolution only support dilate equals 1 } else { Shape<4> pshape = out.Slice(i, i + step).shape_; pshape[2] += 2 * param_.pad[0]; @@ -142,7 +137,7 @@ class DeconvolutionOp : public Operator { param_.kernel[0], param_.kernel[1], param_.stride[0], - param_.dilate[0]), + 1), // Deconvolution only support dilate equals 1 out[i][0].shape_); } } @@ -202,16 +197,14 @@ class DeconvolutionOp : public Operator { param_.kernel[1], param_.stride[0], param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + 1, 1); // Deconvolution only support dilate equals 1 } else { temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), param_.pad[0], param_.pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], param_.stride[1], - param_.dilate[0], - param_.dilate[1]); + 1, 1); // Deconvolution only support dilate equals 1 } const index_t gstride = temp_col.size(0) / param_.num_group; for (uint32_t gid = 0; gid < param_.num_group; ++gid) { @@ -329,8 +322,6 @@ class DeconvolutionProp : public OperatorProperty { << "incorrect kernel size: " << param_.kernel; CHECK_GE(param_.stride.Size(), 0) \ << "incorrect stride size: " << param_.stride; - CHECK_EQ(param_.dilate.Size(), 1) \ - << "Dilate not supported in deconvolution, incorrect stride size: " << param_.stride; (*out_shape)[deconv::kOut][1] = param_.num_filter; (*out_shape)[deconv::kOut][2] = param_.stride[0] * (dshape[2] - 1) + ksize_y - 2 * param_.pad[0]; From cfdcd9c7d0389485550625e758bd0c27f0f9a723 Mon Sep 17 00:00:00 2001 From: hiraditya Date: Sun, 3 Jan 2016 10:26:16 -0600 Subject: [PATCH 22/32] Reorder ldflags --- example/cpp/Makefile | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/example/cpp/Makefile b/example/cpp/Makefile index f8a85278a2c0..1de957de4f75 100644 --- a/example/cpp/Makefile +++ b/example/cpp/Makefile @@ -1,11 +1,16 @@ CFLAGS=-I ../../include -Wall -O3 -msse3 -funroll-loops -Wno-unused-parameter -Wno-unknown-pragmas -fopenmp -I ../../mshadow -I ../../dmlc-core/include LDFLAGS=-L ../../lib -lmxnet -lopenblas -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 -DMSHADOW_USE_CUDA=1 +CC=g++ + mlp: ./mlp.cpp - g++ -std=c++0x $(CFLAGS) $(LDFLAGS) -o $@ $^ + $(CC) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) use_ndarray: ./use_ndarray.cpp - g++ -std=c++0x $(CFLAGS) $(LDFLAGS) -o $@ $^ + $(CC) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) lint: python2 ../../dmlc-core/scripts/lint.py mxnet "cpp" ./ + +clean: + rm -f mlp use_ndarray From 9e999f9996696661e7d0f75fa0cb7acd007dbd93 Mon Sep 17 00:00:00 2001 From: hiraditya Date: Sun, 3 Jan 2016 10:32:57 -0600 Subject: [PATCH 23/32] Include CMakeParseArguments --- cmake/Utils.cmake | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 342689c8256c..0308645df6b8 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -1,3 +1,6 @@ +# For cmake_parse_arguments +include(CMakeParseArguments) + ################################################################################################ # Command alias for debugging messages # Usage: @@ -395,4 +398,5 @@ function(mxnet_source_group group) file(GLOB_RECURSE srcs2 ${CAFFE_SOURCE_GROUP_GLOB_RECURSE}) source_group(${group} FILES ${srcs2}) endif() -endfunction() \ No newline at end of file +endfunction() + From fbd38131cd0bad96ec7910ff71cef051beec073a Mon Sep 17 00:00:00 2001 From: Yanghao Li Date: Sun, 3 Jan 2016 19:41:49 -0800 Subject: [PATCH 24/32] fix scale in Xavier initializer --- python/mxnet/initializer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index ef371b190808..987c0fe104b3 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -189,7 +189,7 @@ class Xavier(Initializer): def __init__(self, rnd_type="uniform", factor_type="avg", magnitude=3): self.rnd_type = rnd_type self.factor_type = factor_type - self.magnitude = magnitude + self.magnitude = float(magnitude) def _init_weight(self, _, arr): @@ -197,7 +197,7 @@ def _init_weight(self, _, arr): fan_in, fan_out = np.prod(shape[1:]), shape[0] factor = 1 if self.factor_type == "avg": - factor = (fan_in + fan_out) / 2 + factor = (fan_in + fan_out) / 2.0 elif self.factor_type == "in": factor = fan_in elif self.factor_type == "out": From 222ee2d5e406e32cfeda7f6b8014f22853ec276b Mon Sep 17 00:00:00 2001 From: Aditya Kumar Date: Sun, 3 Jan 2016 23:19:35 -0600 Subject: [PATCH 25/32] CC to CXX --- example/cpp/Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/example/cpp/Makefile b/example/cpp/Makefile index 1de957de4f75..dc61757126d1 100644 --- a/example/cpp/Makefile +++ b/example/cpp/Makefile @@ -1,13 +1,13 @@ CFLAGS=-I ../../include -Wall -O3 -msse3 -funroll-loops -Wno-unused-parameter -Wno-unknown-pragmas -fopenmp -I ../../mshadow -I ../../dmlc-core/include LDFLAGS=-L ../../lib -lmxnet -lopenblas -DMSHADOW_USE_CBLAS=1 -DMSHADOW_USE_MKL=0 -DMSHADOW_USE_CUDA=1 -CC=g++ +CXX=g++ mlp: ./mlp.cpp - $(CC) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) + $(CXX) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) use_ndarray: ./use_ndarray.cpp - $(CC) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) + $(CXX) -std=c++0x $(CFLAGS) -o $@ $^ $(LDFLAGS) lint: python2 ../../dmlc-core/scripts/lint.py mxnet "cpp" ./ From b02df0028ab01a758e7e930d34feefd8b2c7d34b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?ruixinzhang=28=E5=BC=A0=E7=9D=BF=E6=AC=A3=29?= Date: Mon, 4 Jan 2016 19:36:17 +0800 Subject: [PATCH 26/32] Add Interpolation Options Add interpolation method selection options for tools/im2rec and src/io/image_augmenter --- Makefile | 2 +- src/io/image_augmenter.h | 37 +++++++++++++++- tools/im2rec.cc | 96 +++++++++++++++++++++++++++++++--------- 3 files changed, 111 insertions(+), 24 deletions(-) diff --git a/Makefile b/Makefile index 63fa524b016f..08356bb02f79 100644 --- a/Makefile +++ b/Makefile @@ -165,7 +165,7 @@ bin/im2rec: tools/im2rec.cc $(ALL_DEP) $(BIN) : @mkdir -p $(@D) - $(CXX) $(CFLAGS) -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) + $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) include tests/cpp/unittest.mk diff --git a/src/io/image_augmenter.h b/src/io/image_augmenter.h index 5c7c7c2359d5..a359b99d6246 100644 --- a/src/io/image_augmenter.h +++ b/src/io/image_augmenter.h @@ -54,6 +54,8 @@ struct ImageAugmentParam : public dmlc::Parameter { int rotate; /*! \brief filled color while padding */ int fill_value; + /*! \brief interpolation method 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand */ + int inter_method; /*! \brief shape of the image data*/ TShape data_shape; // declare parameters @@ -95,6 +97,8 @@ struct ImageAugmentParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(data_shape) .set_expect_ndim(3).enforce_nonzero() .describe("Dataset Param: Shape of each instance generated by the DataIter."); + DMLC_DECLARE_FIELD(inter_method).set_default(1) + .describe("Augmentation Param: 0-NN 1-bilinear 2-cubic 3-area 4-lanczos4 9-auto 10-rand."); } }; @@ -125,6 +129,27 @@ class ImageAugmenter { } } } + /*! + *\brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC + *\ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND + */ + virtual int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, + int new_height, common::RANDOM_ENGINE *prnd) { + if (inter_method == 9) { + if (new_width > old_width && new_height > old_height) { + return 2; // CV_INTER_CUBIC for enlarge + } else if (new_width rand_uniform_int(0, 4); + return rand_uniform_int(*prnd); + } else { + return inter_method; + } + } #if MXNET_USE_OPENCV #ifdef _MSC_VER #define M_PI CV_PI @@ -181,8 +206,13 @@ class ImageAugmenter { float ori_center_height = M.at(1, 0) * src.cols + M.at(1, 1) * src.rows; M.at(0, 2) = (new_width - ori_center_width) / 2; M.at(1, 2) = (new_height - ori_center_height) / 2; + CHECK((param_.inter_method >= 1 && param_.inter_method <= 4) || + (param_.inter_method >= 9 && param_.inter_method <= 10)) + << "invalid inter_method: valid value 0,1,2,3,9,10"; + int interpolation_method = GetInterMethod(param_.inter_method, + src.cols, src.rows, new_width, new_height, prnd); cv::warpAffine(src, temp_, M, cv::Size(new_width, new_height), - cv::INTER_LINEAR, + interpolation_method, cv::BORDER_CONSTANT, cv::Scalar(param_.fill_value, param_.fill_value, param_.fill_value)); res = temp_; @@ -206,7 +236,10 @@ class ImageAugmenter { y /= 2; x /= 2; } cv::Rect roi(x, y, rand_crop_size, rand_crop_size); - cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1])); + int interpolation_method = GetInterMethod(param_.inter_method, rand_crop_size, rand_crop_size, + param_.data_shape[2], param_.data_shape[1], prnd); + cv::resize(res(roi), res, cv::Size(param_.data_shape[2], param_.data_shape[1]) + , 0, 0, interpolation_method); } else { CHECK(static_cast(res.rows) >= param_.data_shape[1] && static_cast(res.cols) >= param_.data_shape[2]) diff --git a/tools/im2rec.cc b/tools/im2rec.cc index fb8599ca471d..408cf9bc35a5 100644 --- a/tools/im2rec.cc +++ b/tools/im2rec.cc @@ -21,7 +21,27 @@ #include #include #include "../src/io/image_recordio.h" - +#include +/*! + *\brief get interpolation method with given inter_method, 0-CV_INTER_NN 1-CV_INTER_LINEAR 2-CV_INTER_CUBIC + *\ 3-CV_INTER_AREA 4-CV_INTER_LANCZOS4 9-AUTO(cubic for enlarge, area for shrink, bilinear for others) 10-RAND(0-4) + */ +int GetInterMethod(int inter_method, int old_width, int old_height, int new_width, int new_height, std::mt19937& prnd) { + if (inter_method == 9) { + if (new_width > old_width && new_height > old_height) { + return 2; // CV_INTER_CUBIC for enlarge + } else if (new_width rand_uniform_int(0, 4); + return rand_uniform_int(prnd); + } else { + return inter_method; + } +} int main(int argc, char *argv[]) { if (argc < 4) { printf("Usage: [additional parameters in form key=value]\n"\ @@ -34,6 +54,7 @@ int main(int argc, char *argv[]) { "\tcenter_crop=CENTER_CROP[default=0] specify whether to crop the center image to make it square.\n"\ "\tquality=QUALITY[default=80] JPEG quality for encoding (1-100, default: 80) or PNG compression for encoding (1-9, default: 3).\n"\ "\tencoding=ENCODING[default='.jpg'] Encoding type. Can be '.jpg' or '.png'\n"\ + "\tinter_method=INTER_METHOD[default=1] NN(0) BILINEAR(1) CUBIC(2) AREA(3) LANCZOS4(4) AUTO(9) RAND(10).\n"\ "\tunchanged=UNCHANGED[default=0] Keep the original image encoding, size and color. If set to 1, it will ignore the others parameters.\n"); return 0; } @@ -44,7 +65,8 @@ int main(int argc, char *argv[]) { int center_crop = 0; int quality = 80; int color_mode = CV_LOAD_IMAGE_COLOR; - int unchanged=0; + int unchanged = 0; + int inter_method = CV_INTER_LINEAR; std::string encoding(".jpg"); for (int i = 4; i < argc; ++i) { char key[128], val[128]; @@ -58,16 +80,16 @@ int main(int argc, char *argv[]) { if (!strcmp(key, "color")) color_mode = atoi(val); if (!strcmp(key, "encoding")) encoding = std::string(val); if (!strcmp(key, "unchanged")) unchanged = atoi(val); + if (!strcmp(key, "inter_method")) inter_method = atoi(val); } } // Check parameters ranges - if(color_mode != -1 && color_mode != 0 && color_mode != 1) { + if (color_mode != -1 && color_mode != 0 && color_mode != 1) { LOG(FATAL) << "Color mode must be -1, 0 or 1."; } - if(encoding != std::string(".jpg") && encoding != std::string(".png")) { + if (encoding != std::string(".jpg") && encoding != std::string(".png")) { LOG(FATAL) << "Encoding mode must be .jpg or .png."; } - if (new_size > 0) { LOG(INFO) << "New Image Size: Short Edge " << new_size; } else { @@ -81,13 +103,39 @@ int main(int argc, char *argv[]) { } if (color_mode == -1) { LOG(INFO) << "Keep original color mode"; - } + } LOG(INFO) << "Encoding is " << encoding; - if(encoding == std::string(".png") and quality > 9) { + if (encoding == std::string(".png") and quality > 9) { quality = 3; } - + if (inter_method != 1) { + switch (inter_method) { + case 0: + LOG(INFO) << "Use inter_method CV_INTER_NN"; + break; + case 2: + LOG(INFO) << "Use inter_method CV_INTER_CUBIC"; + break; + case 3: + LOG(INFO) << "Use inter_method CV_INTER_AREA"; + break; + case 4: + LOG(INFO) << "Use inter_method CV_INTER_LANCZOS4"; + break; + case 9: + LOG(INFO) << "Use inter_method mod auto(cubic for enlarge, area for shrink)"; + break; + case 10: + LOG(INFO) << "Use inter_method mod rand(nn/bilinear/cubic/area/lanczos4)"; + break; + default: + LOG(INFO) << "Unkown inter_method"; + return 0; + } + } + std::random_device rd; + std::mt19937 prnd(rd()); using namespace dmlc; const static size_t kBufferSize = 1 << 20UL; std::string root = argv[2]; @@ -95,7 +143,7 @@ int main(int argc, char *argv[]) { size_t imcnt = 0; double tstart = dmlc::GetTime(); dmlc::InputSplit *flist = dmlc::InputSplit:: - Create(argv[1], partid, nsplit, "text"); + Create(argv[1], partid, nsplit, "text"); std::ostringstream os; if (nsplit == 1) { os << argv[3]; @@ -110,12 +158,11 @@ int main(int argc, char *argv[]) { std::vector decode_buf; std::vector encode_buf; std::vector encode_params; - if(encoding == std::string(".png")) { + if (encoding == std::string(".png")) { encode_params.push_back(CV_IMWRITE_PNG_COMPRESSION); encode_params.push_back(quality); LOG(INFO) << "PNG encoding compression: " << quality; - } - else { + } else { encode_params.push_back(CV_IMWRITE_JPEG_QUALITY); encode_params.push_back(quality); LOG(INFO) << "JPEG encoding quality: " << quality; @@ -126,7 +173,7 @@ int main(int argc, char *argv[]) { std::string sline(static_cast(line.dptr), line.size); std::istringstream is(sline); if (!(is >> rec.header.image_id[0] >> rec.header.label)) continue; - for (int k = 1; k < label_width; ++ k) { + for (int k = 1; k < label_width; ++k) { float tmp; CHECK(is >> tmp) << "Invalid ImageList, did you provide the correct label_width?"; @@ -154,8 +201,7 @@ int main(int argc, char *argv[]) { if (nread != kBufferSize) break; } delete fi; - - if(unchanged != 1) { + if (unchanged != 1) { cv::Mat img = cv::imdecode(decode_buf, color_mode); CHECK(img.data != NULL) << "OpenCV decode fail:" << path; cv::Mat res = img; @@ -169,12 +215,21 @@ int main(int argc, char *argv[]) { img = img(cv::Range(0, img.rows), cv::Range(margin, margin + img.rows)); } } + int interpolation_method = 1; if (img.rows > img.cols) { - cv::resize(img, res, cv::Size(new_size, img.rows * new_size / img.cols), - 0, 0, CV_INTER_LINEAR); + if (img.cols != new_size) { + interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size, img.rows * new_size / img.cols, prnd); + cv::resize(img, res, cv::Size(new_size, img.rows * new_size / img.cols), 0, 0, interpolation_method); + } else { + res = img.clone(); + } } else { - cv::resize(img, res, cv::Size(new_size * img.cols / img.rows, new_size), - 0, 0, CV_INTER_LINEAR); + if (img.rows != new_size) { + interpolation_method = GetInterMethod(inter_method, img.cols, img.rows, new_size * img.cols / img.rows, new_size, prnd); + cv::resize(img, res, cv::Size(new_size * img.cols / img.rows, new_size), 0, 0, interpolation_method); + } else { + res = img.clone(); + } } } encode_buf.clear(); @@ -183,8 +238,7 @@ int main(int argc, char *argv[]) { blob.resize(bsize + encode_buf.size()); memcpy(BeginPtr(blob) + bsize, BeginPtr(encode_buf), encode_buf.size()); - } - else { + } else { size_t bsize = blob.size(); blob.resize(bsize + decode_buf.size()); memcpy(BeginPtr(blob) + bsize, From 396c89581bacf3bb6ed1a3d01dbaa11fadfa4090 Mon Sep 17 00:00:00 2001 From: Ruixiang Zhang Date: Sat, 2 Jan 2016 02:34:48 +0800 Subject: [PATCH 27/32] add accnn tool --- tools/accnn/README.md | 83 +++++++++++++++++++++++ tools/accnn/acc_conv.py | 77 +++++++++++++++++++++ tools/accnn/acc_fc.py | 57 ++++++++++++++++ tools/accnn/accnn.py | 37 +++++++++++ tools/accnn/config.json | 21 ++++++ tools/accnn/rank_selection.py | 86 ++++++++++++++++++++++++ tools/accnn/utils.py | 122 ++++++++++++++++++++++++++++++++++ 7 files changed, 483 insertions(+) create mode 100644 tools/accnn/README.md create mode 100644 tools/accnn/acc_conv.py create mode 100644 tools/accnn/acc_fc.py create mode 100644 tools/accnn/accnn.py create mode 100644 tools/accnn/config.json create mode 100644 tools/accnn/rank_selection.py create mode 100644 tools/accnn/utils.py diff --git a/tools/accnn/README.md b/tools/accnn/README.md new file mode 100644 index 000000000000..02f10d111e2d --- /dev/null +++ b/tools/accnn/README.md @@ -0,0 +1,83 @@ +# Accelerate Convolutional Neural Networks + +This tool aims to accelerate the test-time computation and decrease number of parameters of deep CNNs. + + +## How to use + +Use ``accnn.py`` to get a new model by specifying an original model and the speeding-up ratio. + +You may provide a json to explicitly control the architecture of the new model, otherwise the rank-selection algorithm would be used to do it automatically and the configuration would be saved to file ``config.json``. + +``acc_conv.py`` and ``acc_fc.py`` would be involved automatically when using ``accnn.py`` while ``acc_conv.py`` and ``acc_fc.py`` can also be used seperately. + +## Example + +###Speedup whole network + +- Speed up a model by 2 times and use ``rank-selection`` to determine ranks of each layer automatically + + ```bash + python accnn.py -m MODEL-PREFIX --save-model new-vgg16 --ratio 2 + ``` + +- Use your own configuration file without ``rank-selection`` + + ```bash + python accnn.py -m MODEL-PREFIX --save-model new-model --config YOUR-CONFIG_JSON + ``` + +###Speedup a single layer + +- Decompose a convolutional layer: + + ```bash + python acc_conv.py -m MODEL-PREFIX --layer LAYER-NAME --K NUM-FILTER --save-model new-model + ``` + +- Decompose a fullyconnected layer: + + ```bash + python acc_fc.py -m MODEL-PREFIX --layer LAYER-NAME --K NUM-HIDDEN --save-model new-model + ``` +- uses `--help` to see more options + + +## Results + +The experiments are carried on a single machine with four Nvidia Titan X GPUs. The top-5 accuracy is evaluated on ImageNet validation dataset. + + + +| Model | Top-5 accuracy | Theoretical speed up | CPU speed up | GPU speed up | +| ------------- | -----------: | -------------: | -----------: | -----------: | +| model0 | 89.6% | 1x| 1x| 1x| +| model1 | 88.6% | 2.4x| 2.2x| 1.1x| +| model2 | 89.8% | 2.4x| 2.2x| 1.1x| +| model3 | 87.5% | 3x| 2.6x| 1.2x| +| model4 | 89.6% | 3x| 2.6x| 1.2x| + + + * ``model0`` is the original VGG16 model directly converted from Caffe Model Zoo + * ``model1`` is the accelerated model based on ``config.json`` + * ``model2`` is the same as ``model1`` but is fine-tuned on ImageNet training dataset for 5 epochs + * ``model3`` is the accelerated model based on rank-selection with 3 times speeding up + * ``model4`` is the same as ``model3`` but is fine-tuned on ImageNet training dataset for 5 epochs + * The experiments in GPU are carried with cuDNN 4 + + +## Notes + +* This tool is verified on the [VGG-16](https://gist.github.com/jimmie33/27c1c0a7736ba66c2395) model converted from Caffe by ``caffe_converter`` tool. + +* ``accnn.py`` tool only supports single input and output + +* This tool mainly implements the algorithm of Cheng *et al.* [2] to decompose a convolutional layer to two convolutional layers both in spatial dimensions and across channels. ``acc_conv.py`` provides the function to replace a ``(N,d,d)`` conv. layer by two ``(K,d,1)`` and ``(N,1,d)`` conv. layers. + +* The idea of ``rank-selection`` tool is based on the related work of Zhang *et al* [1] that we could use the product of PCA energy to determine the rank for each layer. + +## Reference Paper + +[1] Zhang, Xiangyu, et al. "Efficient and accurate approximations of nonlinear convolutional networks." arXiv preprint arXiv:1411.4229 (2014). + +[2] Tai, Cheng, et al. "Convolutional neural networks with low-rank regularization." arXiv preprint arXiv:1511.06067 (2015). diff --git a/tools/accnn/acc_conv.py b/tools/accnn/acc_conv.py new file mode 100644 index 000000000000..8f468def14fc --- /dev/null +++ b/tools/accnn/acc_conv.py @@ -0,0 +1,77 @@ +import numpy as np +from scipy import linalg as LA +import mxnet as mx +import argparse +import utils + +def conv_vh_decomposition(model, args): + W = model.arg_params[args.layer+'_weight'].asnumpy() + N, C, y, x = W.shape + b = model.arg_params[args.layer+'_bias'].asnumpy() + W = W.transpose((1,2,0,3)).reshape((C*y, -1)) + + U, D, Q = np.linalg.svd(W, full_matrices=False) + sqrt_D = LA.sqrtm(np.diag(D)) + K = args.K + V = U[:,:K].dot(sqrt_D[:K, :K]) + H = Q.T[:,:K].dot(sqrt_D[:K, :K]) + V = V.T.reshape(K, C, y, 1) + b_1 = np.zeros((K, )) + H = H.reshape(N, x, 1, K).transpose((0,3,2,1)) + b_2 = b + + W1, b1, W2, b2 = V, b_1, H, b_2 + def sym_handle(data, node): + kernel = eval(node['param']['kernel']) + pad = eval(node['param']['pad']) + name = node['name'] + + name1 = name + '_v' + kernel1 = tuple((kernel[0], 1)) + pad1 = tuple((pad[0], 0)) + num_filter = W1.shape[0] + sym1 = mx.symbol.Convolution(data=data, kernel=kernel1, pad=pad1, num_filter=num_filter, name=name1) + + name2 = name + '_h' + kernel2 = tuple((1, kernel[1])) + pad2 = tuple((0, pad[1])) + num_filter = W2.shape[0] + sym2 = mx.symbol.Convolution(data=sym1, kernel=kernel2, pad=pad2, num_filter=num_filter, name=name2) + return sym2 + + def arg_handle(arg_shape_dic, arg_params): + name1 = args.layer + '_v' + name2 = args.layer + '_h' + weight1 = mx.ndarray.array(W1) + bias1 = mx.ndarray.array(b1) + weight2 = mx.ndarray.array(W2) + bias2 = mx.ndarray.array(b2) + assert weight1.shape == arg_shape_dic[name1+'_weight'], 'weight1' + assert weight2.shape == arg_shape_dic[name2+'_weight'], 'weight2' + assert bias1.shape == arg_shape_dic[name1+'_bias'], 'bias1' + assert bias2.shape == arg_shape_dic[name2+'_bias'], 'bias2' + + arg_params[name1 + '_weight'] = weight1 + arg_params[name1 + '_bias'] = bias1 + arg_params[name2 + '_weight'] = weight2 + arg_params[name2 + '_bias'] = bias2 + + new_model = utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle) + return new_model + +def main(): + model = utils.load_model(args) + new_model = conv_vh_decomposition(model, args) + new_model.save(args.save_model) + +if __name__ == '__main__': + parser=argparse.ArgumentParser() + parser.add_argument('-m', '--model', help='the model to speed up') + parser.add_argument('-g', '--gpus', default='0,1,2,3', help='the gpus to be used in ctx') + parser.add_argument('--load-epoch',type=int,default=1) + parser.add_argument('--layer') + parser.add_argument('--K', type=int) + parser.add_argument('--save-model') + args = parser.parse_args() + main() + \ No newline at end of file diff --git a/tools/accnn/acc_fc.py b/tools/accnn/acc_fc.py new file mode 100644 index 000000000000..a7b7da163990 --- /dev/null +++ b/tools/accnn/acc_fc.py @@ -0,0 +1,57 @@ +import numpy as np +from scipy import linalg as LA +import mxnet as mx +import argparse +import utils +import pdb + +def fc_decomposition(model, args): + W = model.arg_params[args.layer+'_weight'].asnumpy() + b = model.arg_params[args.layer+'_bias'].asnumpy() + W = W.reshape((W.shape[0],-1)) + b = b.reshape((b.shape[0],-1)) + u, s, v = LA.svd(W, full_matrices=False) + s = np.diag(s) + t = u.dot(s.dot(v)) + rk = args.K + P = u[:,:rk] + Q = s[:rk,:rk].dot(v[:rk,:]) + + name1 = args.layer + '_red' + name2 = args.layer + '_rec' + def sym_handle(data, node): + W1, W2 = Q, P + sym1 = mx.symbol.FullyConnected(data=data, num_hidden=W1.shape[0], no_bias=True, name=name1) + sym2 = mx.symbol.FullyConnected(data=sym1, num_hidden=W2.shape[0], no_bias=False, name=name2) + return sym2 + + def arg_handle(arg_shape_dic, arg_params): + W1, W2 = Q, P + W1 = W1.reshape(arg_shape_dic[name1+'_weight']) + weight1 = mx.ndarray.array(W1) + W2 = W2.reshape(arg_shape_dic[name2+'_weight']) + b2 = b.reshape(arg_shape_dic[name2+'_bias']) + weight2 = mx.ndarray.array(W2) + bias2 = mx.ndarray.array(b2) + arg_params[name1 + '_weight'] = weight1 + arg_params[name2 + '_weight'] = weight2 + arg_params[name2 + '_bias'] = bias2 + + new_model = utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle) + return new_model + +def main(): + model = utils.load_model(args) + new_model = fc_decomposition(model, args) + new_model.save(args.save_model) + +if __name__ == '__main__': + parser=argparse.ArgumentParser() + parser.add_argument('-m', '--model', help='the model to speed up') + parser.add_argument('-g', '--gpus', default='0,1,2,3', help='the gpus to be used in ctx') + parser.add_argument('--load-epoch',type=int,default=1) + parser.add_argument('--layer') + parser.add_argument('--K', type=int) + parser.add_argument('--save-model') + args = parser.parse_args() + main() diff --git a/tools/accnn/accnn.py b/tools/accnn/accnn.py new file mode 100644 index 000000000000..a5e3c8fdd5bf --- /dev/null +++ b/tools/accnn/accnn.py @@ -0,0 +1,37 @@ +import mxnet as mx +import argparse +import utils +import acc_conv +import acc_fc +import rank_selection +import collections +import json +import sys + +parser = argparse.ArgumentParser() +parser.add_argument('-m', '--model', help='the model to speed up') +parser.add_argument('-g', '--gpus', default='0', help='the gpus will be used, e.g "0,1,2,3"') +parser.add_argument('--load-epoch',type=int, default=1, help="load the model on an epoch using the model-prefix") +parser.add_argument('--save-model', help='output model prefix') +parser.add_argument('--config', default=None, help='specify the config file') +parser.add_argument('--ratio', type=float, default=2, help='speed up ratio') +args = parser.parse_args() + +model = utils.load_model(args) +if args.config: + args.config = json.load(open(args.config, 'r')) +else: + config = {} + config['conv_params'] = rank_selection.get_ranksel(model, args.ratio) + config['fc_params'] = {} + json.dump(config, open('config-rksel-%.1f.json'%(args.ratio), 'w'), indent=2) + +new_model = model +Args = collections.namedtuple('ConvArgs', 'layer K') +for layer, K in args.config['conv_params'].iteritems(): + arg = Args(layer=layer, K=K) + new_model = acc_conv.conv_vh_decomposition(new_model, arg) +for layer, K in args.config['fc_params'].iteritems(): + arg = Args(layer=layer, K=K) + new_model = acc_fc.fc_decomposition(new_model, arg) +new_model.save(args.save_model, 1) diff --git a/tools/accnn/config.json b/tools/accnn/config.json new file mode 100644 index 000000000000..8e086fb65e10 --- /dev/null +++ b/tools/accnn/config.json @@ -0,0 +1,21 @@ +{ + "conv_params": { + "conv1_1": 5, + "conv1_2": 32, + "conv2_1": 64, + "conv2_2": 64, + "conv3_1": 96, + "conv3_2": 160, + "conv3_3": 192, + "conv4_1": 256, + "conv4_2": 256, + "conv4_3": 320, + "conv5_1": 384, + "conv5_2": 384, + "conv5_3": 384 + }, + "fc_params": { + "fc6": 2048, + "fc7": 2048 + } +} \ No newline at end of file diff --git a/tools/accnn/rank_selection.py b/tools/accnn/rank_selection.py new file mode 100644 index 000000000000..57e3bcc8acd1 --- /dev/null +++ b/tools/accnn/rank_selection.py @@ -0,0 +1,86 @@ +import numpy as np +import mxnet as mx +import json +import utils +import math +import sys + +def calc_complexity(ishape, node): + y, x = map(int, eval(node['param']['kernel'])) + N = int(node['param']['num_filter']) + C, Y, X = ishape + return x*(N+C)*X*Y, x*y*N*C*X*Y + +def calc_eigenvalue(model, node): + W = model.arg_params[node['name'] + '_weight'].asnumpy() + N, C, y, x = W.shape + W = W.transpose((1,2,0,3)).reshape((C*y, -1)) + U, D, Q = np.linalg.svd(W, full_matrices=False) + return D + +def get_ranksel(model, ratio): + conf = json.loads(model.symbol.tojson()) + _, output_shapes, _ = model.symbol.get_internals().infer_shape(data=(1,3,224,224)) + out_names = model.symbol.get_internals().list_outputs() + out_shape_dic = dict(zip(out_names, output_shapes)) + nodes = conf['nodes'] + nodes = utils.topsort(nodes) + C = [] + D = [] + S = [] + conv_names = [] + EC = 0 + for node in nodes: + if node['op'] == 'Convolution': + input_nodes = [nodes[int(j[0])] for j in node['inputs']] + data = [input_node['name'] for input_node in input_nodes\ + if not input_node['name'].startswith(node['name'])][0] + if utils.is_input(node): + ishape = (3, 224, 224) + else: + ishape = out_shape_dic[data + '_output'][1:] + C.append(calc_complexity(ishape, node)) + D.append(int(node['param']['num_filter'])) + S.append(calc_eigenvalue(model, node)) + conv_names.append(node['name']) + EC += C[-1][1] + for s in S: + ss = sum(s) + for i in xrange(1, len(s)): + s[i] += s[i-1] + n = len(C) + EC /= ratio + dp = [{}, {}] + dpc = [{} for _ in xrange(n)] + now, nxt = 0, 1 + dp[now][0] = 0 + for i in xrange(n): + dp[nxt] = {} + sys.stdout.flush() + for now_c, now_v in dp[now].iteritems(): + for d in xrange(min(len(S[i]), D[i])): + nxt_c = now_c + (d+1)*C[i][0] + if nxt_c > EC: + continue + nxt_v = dp[now][now_c] + math.log(S[i][d]) + if dp[nxt].has_key(nxt_c): + if nxt_v > dp[nxt][nxt_c]: + dp[nxt][nxt_c] = nxt_v + dpc[i][nxt_c] = (d,now_c) + else: + dp[nxt][nxt_c] = nxt_v + dpc[i][nxt_c] = (d,now_c) + now, nxt = nxt, now + maxv = -1e9 + target_c = 0 + for c,v in dp[now].iteritems(): + assert c <= EC, 'False' + if v > maxv: + maxv = v + target_c = c + res = [0]*n + nowc = target_c + for i in xrange(n-1,-1,-1): + res[i] = dpc[i][nowc][0] + nowc = dpc[i][nowc][1] + return dict(zip(conv_names, res)) diff --git a/tools/accnn/utils.py b/tools/accnn/utils.py new file mode 100644 index 000000000000..a57a384b1fab --- /dev/null +++ b/tools/accnn/utils.py @@ -0,0 +1,122 @@ +import mxnet as mx +import copy +import json + +def load_model(args): + devs = mx.cpu() if args.gpus == None else [mx.gpu(int(i)) for i in args.gpus.split(',')] + return mx.model.FeedForward.load(args.model, args.load_epoch, ctx=devs) + +def topsort(nodes): + n = len(nodes) + deg = [0]*n + g = [[] for _ in xrange(n)] + for i,node in enumerate(nodes): + if node.has_key('inputs'): + for j in node['inputs']: + deg[i] += 1 + g[j[0]].append(i) + if node['name'] == '': + print node + print '!!!',j[0] + from collections import deque + q = deque([i for i in xrange(n) if deg[i]==0]) + res = [] + for its in xrange(n): + i = q.popleft() + res.append(nodes[i]) + for j in g[i]: + deg[j] -= 1 + if deg[j] == 0: + q.append(j) + new_ids=dict([(node['name'],i) for i,node in enumerate(res)]) + for node in res: + if node.has_key('inputs'): + for j in node['inputs']: + j[0]=new_ids[nodes[j[0]]['name']] + return res + +def is_input(node): + name = node['name'] + return len(node['inputs']) == 0 and ('weight' not in name) and ('bias' not in name) and ('label' not in name) + +def replace_conv_layer(layer_name, old_model, sym_handle, arg_handle): + conf = json.loads(old_model.symbol.tojson()) + sym_dict = {} + nodes = conf['nodes'] + nodes = topsort(nodes) + res_sym = None + new_model = old_model + for i,node in enumerate(nodes): + sym = None + if is_input(node): + sym = mx.symbol.Variable(name='data') + elif node['op'] != 'null': + input_nodes = [nodes[int(j[0])] for j in node['inputs']] + datas = [input_node['name'] for input_node in input_nodes\ + if not input_node['name'].startswith(node['name'])] + try: + data=sym_dict[datas[0]] + except Exception, e: + print 'can not find symbol %s'%(datas[0]) + raise e + if node['name'] == layer_name: + sym = sym_handle(data, node) + else: + if node['op'] == 'Convolution': + kernel = eval(node['param']['kernel']) + pad = eval(node['param']['pad']) + num_filter = int(node['param']['num_filter']) + name = node['name'] + sym = mx.symbol.Convolution(data=data, kernel=kernel, pad=pad, num_filter=num_filter, name=name) + elif node['op'] == 'Activation': + sym = mx.symbol.Activation(data=data, act_type=node['param']['act_type'], name=node['name']) + elif node['op'] == 'Pooling': + kernel = eval(node['param']['kernel']) + pad = eval(node['param']['pad']) + pool_type = node['param']['pool_type'] + stride = eval(node['param']['stride']) + sym = mx.symbol.Pooling(data=data, kernel=kernel, pad=pad, pool_type=pool_type, stride=stride, name=node['name']) + elif node['op'] == 'Dropout': + p = float(node['param']['p']) + sym = mx.symbol.Dropout(data=data, p=p, name=node['name']) + elif node['op'] == 'FullyConnected': + no_bias = True if node['param']['no_bias']=='True' else False + num_hidden = int(node['param']['num_hidden']) + sym = mx.symbol.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, name=node['name']) + elif node['op'] == 'Flatten': + sym = mx.symbol.Flatten(data=data, name=node['name']) + elif node['op'] == 'SoftmaxOutput': + sym = mx.symbol.SoftmaxOutput(data=data, name='softmax') + res_sym = sym + elif node['op'] == 'Reshape': + target_shape = eval(node['param']['target_shape']) + sym = mx.symbol.Reshape(data=data, target_shape=target_shape) + res_sym = sym + else: + raise Exception("Invalid symbol") + if sym: + sym_dict[node['name']] = sym + + arg_params = copy.deepcopy(old_model.arg_params) + if layer_name: + arg_shapes, _, _ = res_sym.infer_shape(data=(1,3,224,224)) + arg_names = res_sym.list_arguments() + arg_shape_dic = dict(zip(arg_names, arg_shapes)) + try: + arg_handle(arg_shape_dic, arg_params) + except Exception, e: + raise Exception('Exception in arg_handle') + + new_model = mx.model.FeedForward( + symbol=res_sym, + ctx=old_model.ctx, + num_epoch=1, + epoch_size=old_model.epoch_size, + optimizer='sgd', + initializer=old_model.initializer, + numpy_batch_size=old_model.numpy_batch_size, + arg_params=arg_params, + aux_params=old_model.aux_params, + allow_extra_params=True, + begin_epoch=old_model.begin_epoch) + return new_model From 1e5e8c0e01dcf9f7ea27d6299a4cf3b0c44a8bd0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 5 Jan 2016 07:02:57 +0800 Subject: [PATCH 28/32] set argument '--gpus' to '0' --- tools/accnn/acc_conv.py | 4 ++-- tools/accnn/acc_fc.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tools/accnn/acc_conv.py b/tools/accnn/acc_conv.py index 8f468def14fc..095e386beebc 100644 --- a/tools/accnn/acc_conv.py +++ b/tools/accnn/acc_conv.py @@ -67,11 +67,11 @@ def main(): if __name__ == '__main__': parser=argparse.ArgumentParser() parser.add_argument('-m', '--model', help='the model to speed up') - parser.add_argument('-g', '--gpus', default='0,1,2,3', help='the gpus to be used in ctx') + parser.add_argument('-g', '--gpus', default='0', help='the gpus to be used in ctx') parser.add_argument('--load-epoch',type=int,default=1) parser.add_argument('--layer') parser.add_argument('--K', type=int) parser.add_argument('--save-model') args = parser.parse_args() main() - \ No newline at end of file + diff --git a/tools/accnn/acc_fc.py b/tools/accnn/acc_fc.py index a7b7da163990..dcc255452b1d 100644 --- a/tools/accnn/acc_fc.py +++ b/tools/accnn/acc_fc.py @@ -48,7 +48,7 @@ def main(): if __name__ == '__main__': parser=argparse.ArgumentParser() parser.add_argument('-m', '--model', help='the model to speed up') - parser.add_argument('-g', '--gpus', default='0,1,2,3', help='the gpus to be used in ctx') + parser.add_argument('-g', '--gpus', default='0', help='the gpus to be used in ctx') parser.add_argument('--load-epoch',type=int,default=1) parser.add_argument('--layer') parser.add_argument('--K', type=int) From 5cd328c71a61acd5e74fb06d7259ecde08298a28 Mon Sep 17 00:00:00 2001 From: Boyuan Deng Date: Tue, 5 Jan 2016 01:53:57 +0100 Subject: [PATCH 29/32] [DOC] minor fix in installation guide --- doc/build.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/build.md b/doc/build.md index b3d8ff2559dd..b4564bef6fbf 100644 --- a/doc/build.md +++ b/doc/build.md @@ -35,7 +35,7 @@ Our goal is to build the shared library: The minimal building requirement is - A recent c++ compiler supporting C++ 11 such as `g++ >= 4.8` or `clang` -- A BLAS library, such as `libblas`, `libblas`, `openblas` `intel mkl` +- A BLAS library, such as `libblas`, `atlas`, `openblas` or `intel mkl` Optional libraries From 35e898e38d99a430781290c2f76a898eb8b7c14a Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Tue, 5 Jan 2016 02:13:37 -0700 Subject: [PATCH 30/32] [example] kaggle ndsb-1 --- example/kaggle-ndsb1/README.md | 34 ++++++++++ example/kaggle-ndsb1/gen_img_list.py | 43 +++++++++++++ example/kaggle-ndsb1/run_local.py | 96 ++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+) create mode 100644 example/kaggle-ndsb1/README.md create mode 100644 example/kaggle-ndsb1/gen_img_list.py create mode 100644 example/kaggle-ndsb1/run_local.py diff --git a/example/kaggle-ndsb1/README.md b/example/kaggle-ndsb1/README.md new file mode 100644 index 000000000000..057c69c5d368 --- /dev/null +++ b/example/kaggle-ndsb1/README.md @@ -0,0 +1,34 @@ +Tutorial for Kaggle NDSB-1 +----- + +This is an MXNet example for Kaggle Nation Data Science Bowl 1. + +In this example we ignored submission part, only show local validation result. + +#### Step 1: Generate image list +- Prepare original data, in layout like +``` +--gen_img_list.py +--data/ + | + |--train/ + | | + | |--acantharia_protist/... + | |--.../ + |--sampleSubmission.csv +``` +- Run command ``` python gen_img_list.py train data/sampleSubmission.csv data/train/ train.lst``` to generate a full image list +- Run command ```sed -n '1, 20000p' train.lst > tr.lst``` to generate local train list +- Run command ```sed -n '20001p, 30337p' train.lst > va.lst``` to generate local validation list + + +#### Step 2: Generate Image Record (new shape with short edge = 48) +- Run command ```../../bin/im2rec tr.lst ./ tr.rec resize=48``` to generate training data record file +- Run command ```../../bin/im2rec va.lst ./ va.rec resize=48``` to generate validation data record file + +#### Step 3: Train Model +- Feel free to change hyper parameter in ```run_local.py``` +- Run ```python run_local.py``` to train the model +- Sample code result: Train-accuracy=60.1%, Validation-accuracy=62.1% + + diff --git a/example/kaggle-ndsb1/gen_img_list.py b/example/kaggle-ndsb1/gen_img_list.py new file mode 100644 index 000000000000..c88fb3c562e6 --- /dev/null +++ b/example/kaggle-ndsb1/gen_img_list.py @@ -0,0 +1,43 @@ +import csv +import os +import sys +import random + +if len(sys.argv) < 4: + print "Usage: gen_img_list.py train/test sample_submission.csv train_folder img.lst" + exit(1) + +random.seed(888) + +task = sys.argv[1] +fc = csv.reader(file(sys.argv[2])) +fi = sys.argv[3] +fo = csv.writer(open(sys.argv[4], "w"), delimiter='\t', lineterminator='\n') + +# make class map +head = fc.next() +head = head[1:] + +# make image list +img_lst = [] +cnt = 0 +if task == "train": + for i in xrange(len(head)): + path = fi + head[i] + lst = os.listdir(fi + head[i]) + for img in lst: + img_lst.append((cnt, i, path + '/' + img)) + cnt += 1 +else: + lst = os.listdir(fi) + for img in lst: + img_lst.append((cnt, 0, fi + img)) + cnt += 1 + +# shuffle +random.shuffle(img_lst) + +#wirte +for item in img_lst: + fo.writerow(item) + diff --git a/example/kaggle-ndsb1/run_local.py b/example/kaggle-ndsb1/run_local.py new file mode 100644 index 000000000000..172035ca443b --- /dev/null +++ b/example/kaggle-ndsb1/run_local.py @@ -0,0 +1,96 @@ +import mxnet as mx +import numpy as np +import logging + +# Example performance: +# INFO:root:Epoch[34] Train-accuracy=0.601388 +# INFO:root:Epoch[34] Validation-accuracy=0.620949 + +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +# running device +dev = mx.gpu() +# batch size and input shape +batch_size = 64 +data_shape = (3, 36, 36) +# training data info for learning rate reduction +num_examples = 20000 +epoch_size = num_examples / batch_size +lr_factor_epoch = 15 +# model saving parameter +model_prefix = "./models/sample_net" + +# train data iterator +train = mx.io.ImageRecordIter( + path_imgrec = "tr.rec", + mean_r = 128, + mean_g = 128, + mean_b = 128, + scale = 0.0078125, + max_aspect_ratio = 0.35, + data_shape = data_shape, + batch_size = batch_size, + rand_crop = True, + rand_mirror = True) + +# validate data iterator +val = mx.io.ImageRecordIter( + path_imgrec = "va.rec", + mean_r = 128, + mean_b = 128, + mean_g = 128, + scale = 0.0078125, + rand_crop = False, + rand_mirror = False, + data_shape = data_shape, + batch_size = batch_size) + +# network definition +# stage 1 +net = mx.sym.Variable("data") +net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=32, pad=(2, 2)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Convolution(data=net, kernel=(5, 5), num_filter=64, pad=(2, 2)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3), stride=(2, 2)) +# stage 2 +net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=64, pad=(1, 1)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=64, pad=(1, 1)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=128, pad=(1, 1)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Pooling(data=net, pool_type="max", kernel=(3, 3), stride=(2, 2)) +# stage 3 +net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=256, pad=(1, 1)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Convolution(data=net, kernel=(3, 3), num_filter=256, pad=(1, 1)) +net = mx.sym.Activation(data=net, act_type="relu") +net = mx.sym.Pooling(data=net, pool_type="avg", kernel=(9, 9), stride=(1, 1)) +# stage 4 +net = mx.sym.Flatten(data=net) +net = mx.sym.Dropout(data=net, p=0.25) +net = mx.sym.FullyConnected(data=net, num_hidden=121) +net = mx.symbol.SoftmaxOutput(data=net, name='softmax') + +# Model parameter +# This model will reduce learning rate by factor 0.1 for every 15 epoch +model = mx.model.FeedForward( + ctx = dev, + symbol = net, + num_epoch = 35, + learning_rate = 0.01, + momentum = 0.9, + wd = 0.0001, + clip_gradient = 5, + lr_scheduler = mx.lr_scheduler.FactorScheduler(step=epoch_size * lr_factor_epoch, factor = 0.1), + initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)) + +# fit the model +model.fit( + X = train, + eval_data = val, + batch_end_callback = mx.callback.Speedometer(batch_size, 50), + epoch_end_callback = mx.callback.do_checkpoint(model_prefix)) + From 9e59f7b05643fccce279c5e26f325fcaf1e39175 Mon Sep 17 00:00:00 2001 From: Lodewic van Twillert Date: Tue, 5 Jan 2016 10:56:09 +0100 Subject: [PATCH 31/32] Update build.md Added instructions how to install the GPU-enabled R package for Windows. --- doc/build.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/doc/build.md b/doc/build.md index b4564bef6fbf..60243578b5aa 100644 --- a/doc/build.md +++ b/doc/build.md @@ -239,6 +239,33 @@ Now you should have the R package as a tar.gz file and you can install it as a n R CMD INSTALL mxnet_0.5.tar.gz ``` + +To install the package using GPU on Windows without building the package from scratch. Note that you need a couple of programs installed already: +- You'll need the [CUDA Toolkit](https://developer.nvidia.com/cuda-toolkit). This depends on Visual Studio, and a free compatible version would be [Visual Studio Community 2013](https://www.visualstudio.com/en-us/news/vs2013-community-vs.aspx). For instructions and compatibility checks, read http://docs.nvidia.com/cuda/cuda-getting-started-guide-for-microsoft-windows/ . + +- You will also need to register as a developer at nvidia and download CUDNN V3, https://developer.nvidia.com/cudnn . + + +(1) Download the mxnet package as a ZIP from the Github repository https://github.com/dmlc/mxnet and unpack it. You will be editing the `/mxnet/R-package` folder. + +(2) Download the most recent GPU-enabled package from the [Releases tab](https://github.com/dmlc/mxnet/releases). Unzip this file so you have a folder `/nocudnn`. Note that this file and the folder you'll save it in will be used for future reference and not directly for installing the package. Only some files will be copied from it into the `R-package` folder. + +(Note: you now have 2 folders we're working with, possibly in different locations, that we'll reference with `R-package/` and `nocudnn/`.) + +(3) Download CUDNN V3 from https://developer.nvidia.com/cudnn. Unpack the .zip file and you'll see 3 folders, `/bin`, `/include`, `/lib`. Copy and replace these 3 folders into `nocudnn/3rdparty/cudnn/`, or unpack the .zip file there directly. + +(4) create the folder `R-package/inst/libs/x64`. We only support 64-bit operating system now, so you need the x64 folder; + +(5) put dll files in `R-package/inst/libs/x64`. + +The first dll file you need is `nocudnn/lib/libmxnet.dll`. The other dll files you need are the ones in all 4 subfolders of `nocudnn/3rdparty/`, for the `cudnn` and `openblas` you'll need to look in the `/bin` folders. There should be 11 dll files now in `R-package/inst/libs/x64`. + +(6) copy the folder `nocudnn/include/` to `R-package/inst/`. So now you should have a folder `R-package/inst/include/` with 3 subfolders. + +(7) `R CMD INSTALL R-package`. Make sure that R is added to your PATH in Environment Variables. Running the command `Where R` in Command Prompt should return the location. + +You might get an error similar to `Error: loading failed for 'i386'`. In that case, run `R CMD INSTALL --no-multiarch R-package`. + Note on Library Build: We isolate the library build with Rcpp end to maximize the portability From f127a693500f834209aef607ebd21e591334277e Mon Sep 17 00:00:00 2001 From: Lodewic van Twillert Date: Tue, 5 Jan 2016 16:25:46 +0100 Subject: [PATCH 32/32] Update build.md Some minor changes for numbering style ( (1) to 1.). And edited step 7. to always include --no-multiarch. --- doc/build.md | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/doc/build.md b/doc/build.md index 60243578b5aa..1d144949cb8d 100644 --- a/doc/build.md +++ b/doc/build.md @@ -246,25 +246,23 @@ To install the package using GPU on Windows without building the package from sc - You will also need to register as a developer at nvidia and download CUDNN V3, https://developer.nvidia.com/cudnn . -(1) Download the mxnet package as a ZIP from the Github repository https://github.com/dmlc/mxnet and unpack it. You will be editing the `/mxnet/R-package` folder. +1. Download the mxnet package as a ZIP from the Github repository https://github.com/dmlc/mxnet and unpack it. You will be editing the `/mxnet/R-package` folder. -(2) Download the most recent GPU-enabled package from the [Releases tab](https://github.com/dmlc/mxnet/releases). Unzip this file so you have a folder `/nocudnn`. Note that this file and the folder you'll save it in will be used for future reference and not directly for installing the package. Only some files will be copied from it into the `R-package` folder. +2. Download the most recent GPU-enabled package from the [Releases tab](https://github.com/dmlc/mxnet/releases). Unzip this file so you have a folder `/nocudnn`. Note that this file and the folder you'll save it in will be used for future reference and not directly for installing the package. Only some files will be copied from it into the `R-package` folder. (Note: you now have 2 folders we're working with, possibly in different locations, that we'll reference with `R-package/` and `nocudnn/`.) -(3) Download CUDNN V3 from https://developer.nvidia.com/cudnn. Unpack the .zip file and you'll see 3 folders, `/bin`, `/include`, `/lib`. Copy and replace these 3 folders into `nocudnn/3rdparty/cudnn/`, or unpack the .zip file there directly. +3. Download CUDNN V3 from https://developer.nvidia.com/cudnn. Unpack the .zip file and you'll see 3 folders, `/bin`, `/include`, `/lib`. Copy and replace these 3 folders into `nocudnn/3rdparty/cudnn/`, or unpack the .zip file there directly. -(4) create the folder `R-package/inst/libs/x64`. We only support 64-bit operating system now, so you need the x64 folder; +4. Create the folder `R-package/inst/libs/x64`. We only support 64-bit operating system now, so you need the x64 folder; -(5) put dll files in `R-package/inst/libs/x64`. +5. Put dll files in `R-package/inst/libs/x64`. The first dll file you need is `nocudnn/lib/libmxnet.dll`. The other dll files you need are the ones in all 4 subfolders of `nocudnn/3rdparty/`, for the `cudnn` and `openblas` you'll need to look in the `/bin` folders. There should be 11 dll files now in `R-package/inst/libs/x64`. -(6) copy the folder `nocudnn/include/` to `R-package/inst/`. So now you should have a folder `R-package/inst/include/` with 3 subfolders. +6. Copy the folder `nocudnn/include/` to `R-package/inst/`. So now you should have a folder `R-package/inst/include/` with 3 subfolders. -(7) `R CMD INSTALL R-package`. Make sure that R is added to your PATH in Environment Variables. Running the command `Where R` in Command Prompt should return the location. - -You might get an error similar to `Error: loading failed for 'i386'`. In that case, run `R CMD INSTALL --no-multiarch R-package`. +7. Run `R CMD INSTALL --no-multiarch R-package`. Make sure that R is added to your PATH in Environment Variables. Running the command `Where R` in Command Prompt should return the location. Note on Library Build: