From 18404b166a37095eee6844275155b681ec2b0b6c Mon Sep 17 00:00:00 2001 From: yajiedesign Date: Tue, 15 Dec 2015 11:11:44 +0800 Subject: [PATCH 01/15] fix compile error with visual studio --- include/mxnet/optimizer.h | 6 ++++++ src/optimizer/sgd-inl.h | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/mxnet/optimizer.h b/include/mxnet/optimizer.h index fb15f2c04ec5..5b438078519c 100644 --- a/include/mxnet/optimizer.h +++ b/include/mxnet/optimizer.h @@ -17,9 +17,15 @@ #include "./base.h" #include "./resource.h" +#if DMLC_USE_CXX11 +#include +#endif + namespace mxnet { +#if !DMLC_USE_CXX11 class NDArray; +#endif class Optimizer { public: diff --git a/src/optimizer/sgd-inl.h b/src/optimizer/sgd-inl.h index 8edbe337a3ae..ce1785102221 100644 --- a/src/optimizer/sgd-inl.h +++ b/src/optimizer/sgd-inl.h @@ -101,7 +101,6 @@ void call_sgd_update_gpu(RunContext ctx, TBlob weight, const TBlob grad, #endif // MXNET_USE_CUDA #if DMLC_USE_CXX11 -#include class SGDOpt : public Optimizer { public: From de254591b33196cc50982aa51e8cd3343d6d0f6c Mon Sep 17 00:00:00 2001 From: luke metz Date: Tue, 15 Dec 2015 11:04:39 -0500 Subject: [PATCH 02/15] refactor model.py _train_multi_device --- python/mxnet/executor.py | 213 ++++++++++++++++++++++++++++++- python/mxnet/model.py | 266 ++++++++++++--------------------------- 2 files changed, 294 insertions(+), 185 deletions(-) diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 05a269d1e2db..3e357fb18974 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -1,5 +1,5 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access, too-many-locals +# pylint: disable=invalid-name, protected-access, too-many-locals, too-many-arguments """Symbolic Executor component of MXNet.""" from __future__ import absolute_import @@ -8,6 +8,9 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from . import ndarray as nd +from .context import cpu +import logging class Executor(object): """ Executor is the actual executing object of MXNet.""" @@ -216,3 +219,211 @@ def debug_str(self): check_call(_LIB.MXExecutorPrint( self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) + +def _split_input_slice(batch_size, work_load_list): + """Get input slice from the input shape. + Parameters + ---------- + batch_size : int + The number of samples in a mini-batch. + work_load_list : list of float or int, optional + The list of work load for different devices, + in the same order as ctx + Returns + ------- + slices : list of slice + The split slices to get a specific slice. + Raises + ------ + ValueError + If there are two many splits such that some slice can be empty. + """ + total_work_load = sum(work_load_list) + batch_num_list = [round(work_load * batch_size / total_work_load) + for work_load in work_load_list] + batch_num_sum = sum(batch_num_list) + if batch_num_sum < batch_size: + batch_num_list[-1] += batch_size - batch_num_sum + slices = [] + end = 0 + for batch_num in batch_num_list: + begin = int(min((end, batch_size))) + end = int(min((begin + batch_num, batch_size))) + if begin >= end: + raise ValueError('Too many slices such that some splits are empty') + slices.append(slice(begin, end)) + return slices + +def _check_arguments(symbol): + """Check the argument names of symbol. + This function checks the duplication of arguments in Symbol. + The check is done for feedforward net for now. 
+ Parameters + ---------- + symbol : Symbol + The network configuration + """ + arg_set = set() + arg_names = symbol.list_arguments() + for name in arg_names: + if name in arg_set: + raise ValueError(('Find duplicated argument name \"%s\", ' + + 'please make the weight name non-duplicated(using name arguments), ' + + 'arguments are %s') % (name, str(arg_names))) + arg_set.add(name) + + aux_set = set() + aux_names = symbol.list_auxiliary_states() + for name in aux_names: + if name in aux_set: + raise ValueError( + ('Find duplicated auxiliary param name \"%s\", ' + + 'please make the weight name non-duplicated(using name arguments), ' + + 'arguments are %s, auxiliary params are %s' + ) % (name, str(arg_names), str(aux_names))) + aux_set.add(name) + +def _load_general(data, targets): + """Load a list of arrays into a list of arrays specified by slices""" + for d_src, d_targets in zip(data, targets): + if isinstance(d_targets, nd.NDArray): + d_src.copyto(d_targets) + else: + for slice_idx, d_dst in d_targets: + d_src[slice_idx].copyto(d_dst) + +def _load_data(batch, targets): + """Load data into sliced arrays""" + _load_general(batch.data, targets) + +def _load_label(batch, targets): + """Load label into sliced arrays""" + _load_general(batch.label, targets) + +class DataParallelExecutorManager(object): + """ Helper class to manage multiple executors for data parallelism. + Parameters + ---------- + symbol : Symbol + output symbol + ctx : list of Context + devices to run on + param_names: list of str + Name of all trainable parameters of the network. + arg_names: list of str + Name of all arguments of the network. + aux_names: list of str + Name of all auxiliary states of the network. + train_data : DataIter + Training data iterator. + work_load_list : list of float or int, optional + The list of work load for different devices, + in the same order as ctx + logger : logging logger + When not specified, default logger will be used. + """ + def __init__(self, symbol, ctx, train_data, + param_names, arg_names, aux_names, + work_load_list=None, logger=None): + if logger is None: + logger = logging + # preparation + num_device = len(ctx) + logger.info('Start training with %s', str(ctx)) + + # make sure the architecture is valid + _check_arguments(symbol) + + if work_load_list is None: + work_load_list = [1] * num_device + assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \ + "Invalid settings for work load. 
" + + slices = _split_input_slice(train_data.batch_size, work_load_list) + self.slices = slices + + self.train_execs = [] + for i in range(len(ctx)): + data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:])) + for k, v in train_data.provide_data} + train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes) + self.train_execs.append(train_exec) + + # data structure + self.data_names = [x[0] for x in train_data.provide_data] + self.label_names = [x[0] for x in train_data.provide_label] + self.aux_names = aux_names + + self.data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] + for name in self.data_names] + self.label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(self.train_execs)] + for name in self.label_names] + + self.param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names] + self.param_names = [arg_names[i] for i in self.param_idx] + self.param_arrays = [[e.arg_arrays[i] for e in self.train_execs] + for i in self.param_idx] + self.grad_arrays = [[e.grad_arrays[i] for e in self.train_execs] + for i in self.param_idx] + + self.aux_arrays = [[e.aux_arrays[i] for e in self.train_execs] + for i in range(len(aux_names))] + + batch_size = train_data.batch_size + + output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in self.train_execs[0].outputs] + self.cpu_output_arrays = [nd.zeros(s) for s in output_shapes] + + def install_monitor(self, monitor): + """ Install monitor on all executors """ + for train_exec in self.train_execs: + monitor.install(train_exec) + + def set_params(self, arg_params, aux_params): + """ set parameter and aux values + Parameters + ---------- + arg_params : list of NDArray + source parameter arrays + aux_params : list of NDArray + source aux arrays + """ + + for texec in self.train_execs: + texec.copy_params_from(arg_params, aux_params) + + def copy_to(self, arg_params, aux_params): + """ Copy data from each executor to `arg_params` and `aux_params` + Parameters + ---------- + arg_params : list of NDArray + target parameter arrays + aux_params : list of NDArray + target aux arrays + Notes + ----- + - This function will inplace update the NDArrays in arg_params and aux_params. 
+ """ + for name, block in zip(self.param_names, self.param_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(arg_params[name]) + for name, block in zip(self.aux_names, self.aux_arrays): + weight = sum(w.copyto(cpu()) for w in block) / len(block) + weight.copyto(aux_params[name]) + + def load_data_batch(self, data_batch): + """ load data and labels into arrays """ + _load_data(data_batch, self.data_arrays) + _load_label(data_batch, self.label_arrays) + + def forward(self, is_train=False): + """ Perform a forward pass on each executor """ + for texec, islice in zip(self.train_execs, self.slices): + texec.forward(is_train=is_train) + for cpu_out, dev_out in zip(self.cpu_output_arrays, texec.outputs): + dev_out.copyto(cpu_out[islice]) + + def backward(self): + """ Perform a backward pass on each executor """ + for texec in self.train_execs: + texec.backward() diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 70774321fdcc..5d9288406f7a 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -16,6 +16,7 @@ from .initializer import Uniform from collections import namedtuple from .optimizer import get_updater +from .executor import DataParallelExecutorManager, _check_arguments, _load_data BASE_ESTIMATOR = object @@ -31,86 +32,6 @@ 'nbatch', 'eval_metric']) -def _load_general(data, targets): - """Load a list of arrays into a list of arrays specified by slices""" - for d_src, d_targets in zip(data, targets): - if isinstance(d_targets, nd.NDArray): - d_src.copyto(d_targets) - else: - for slice_idx, d_dst in d_targets: - d_src[slice_idx].copyto(d_dst) -def _load_data(batch, targets): - """Load data into sliced arrays""" - _load_general(batch.data, targets) -def _load_label(batch, targets): - """Load label into sliced arrays""" - _load_general(batch.label, targets) - - -def _check_arguments(symbol): - """Check the argument names of symbol. - This function checks the duplication of arguments in Symbol. - The check is done for feedforward net for now. - Parameters - ---------- - symbol : Symbol - The network configuration - """ - arg_set = set() - arg_names = symbol.list_arguments() - for name in arg_names: - if name in arg_set: - raise ValueError(('Find duplicated argument name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s') % (name, str(arg_names))) - arg_set.add(name) - - aux_set = set() - aux_names = symbol.list_auxiliary_states() - for name in aux_names: - if name in aux_set: - raise ValueError( - ('Find duplicated auxiliary param name \"%s\", ' + - 'please make the weight name non-duplicated(using name arguments), ' + - 'arguments are %s, auxiliary params are %s' - ) % (name, str(arg_names), str(aux_names))) - aux_set.add(name) - - -def _split_input_slice(batch_size, work_load_list): - """Get input slice from the input shape. - Parameters - ---------- - batch_size : int - The number of samples in a mini-batch. - work_load_list : list of float or int, optional - The list of work load for different devices, - in the same order as ctx - Returns - ------- - slices : list of slice - The split slices to get a specific slice. - Raises - ------ - ValueError - If there are two many splits such that some slice can be empty. 
- """ - total_work_load = sum(work_load_list) - batch_num_list = [round(work_load * batch_size / total_work_load) - for work_load in work_load_list] - batch_num_sum = sum(batch_num_list) - if batch_num_sum < batch_size: - batch_num_list[-1] += batch_size - batch_num_sum - slices = [] - end = 0 - for batch_num in batch_num_list: - begin = int(min((end, batch_size))) - end = int(min((begin + batch_num, batch_size))) - if begin >= end: - raise ValueError('Too many slices such that some splits are empty') - slices.append(slice(begin, end)) - return slices - def _create_kvstore(kvstore, num_device, arg_params): """Create kvstore This function select and create a proper kvstore if given the kvstore type @@ -153,6 +74,46 @@ def _create_kvstore(kvstore, num_device, arg_params): return (kv, update_on_kvstore) +def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, + update_on_kvstore): + """ Initialize kvstore""" + for idx in range(len(param_arrays)): + param_on_devs = param_arrays[idx] + kvstore.init(idx, arg_params[param_names[idx]]) + + if update_on_kvstore: + kvstore.pull(idx, param_on_devs, priority=-idx) + +def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore): + """ Perform update of param_arrays from grad_arrays on kvstore.""" + for index, pair in enumerate(zip(param_arrays, grad_arrays)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + # push gradient, priority is negative index + kvstore.push(index, grad_list, priority=-index) + # pull back the weights + kvstore.pull(index, arg_list, priority=-index) + +def _update_params(param_arrays, grad_arrays, updater, num_device, + kvstore=None): + """ Perform update of param_arrays from grad_arrays not on kvstore.""" + for index, pair in enumerate(zip(param_arrays, grad_arrays)): + arg_list, grad_list = pair + if grad_list[0] is None: + continue + if kvstore: + # push gradient, priority is negative index + kvstore.push(index, grad_list, priority=-index) + # pull back the sum gradients, to the same locations. + kvstore.pull(index, grad_list, priority=-index) + for k, p in enumerate(zip(arg_list, grad_list)): + # faked an index here, to make optimizer create diff + # state for the same index but on diff devs, TODO(mli) + # use a better solution latter + w, g = p + updater(index*num_device+k, g, w) + def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, arg_params, aux_params, begin_epoch, end_epoch, epoch_size, optimizer, @@ -216,66 +177,31 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, """ if logger is None: logger = logging - # preparation - num_device = len(ctx) - logging.info('Start training with %s', str(ctx)) - - # make sure the architecture is valid - _check_arguments(symbol) - - if work_load_list is None: - work_load_list = [1] * num_device - assert isinstance(work_load_list, list) and len(work_load_list) == num_device, \ - "Invalid settings for work load. 
" - slices = _split_input_slice(train_data.batch_size, work_load_list) - train_execs = [] - for i in range(len(ctx)): - data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:])) - for k, v in train_data.provide_data} - train_exec = symbol.simple_bind(ctx[i], 'write', **data_shapes) - if monitor: - monitor.install(train_exec) - train_execs.append(train_exec) - - # data structure - data_names = [x[0] for x in train_data.provide_data] - label_names = [x[0] for x in train_data.provide_label] - - data_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)] - for name in data_names] - label_arrays = [[(slices[i], e.arg_dict[name]) for i, e in enumerate(train_execs)] - for name in label_names] - - param_idx = [i for i in range(len(arg_names)) if arg_names[i] in param_names] - param_names = [arg_names[i] for i in param_idx] - param_arrays = [[e.arg_arrays[i] for e in train_execs] for i in param_idx] - grad_arrays = [[e.grad_arrays[i] for e in train_execs] for i in param_idx] - aux_arrays = [[e.aux_arrays[i] for e in train_execs] for i in range(len(aux_names))] - - for texec in train_execs: - texec.copy_params_from(arg_params, aux_params) + executor_manager = DataParallelExecutorManager(symbol=symbol, + ctx=ctx, + train_data=train_data, + param_names=param_names, + arg_names=arg_names, + aux_names=aux_names, + work_load_list=work_load_list, + logger=logger) + if monitor: + executor_manager.install_monitor(monitor) + + executor_manager.set_params(arg_params, aux_params) if not update_on_kvstore: updater = get_updater(optimizer) - # init kvstore if kvstore: - # init optimizer - if update_on_kvstore: - kvstore.set_optimizer(optimizer) - - # init kv - for idx in range(len(param_arrays)): - param_on_devs = param_arrays[idx] - kvstore.init(idx, arg_params[param_names[idx]]) - - if update_on_kvstore: - kvstore.pull(idx, param_on_devs, priority=-idx) + _initialize_kvstore(kvstore=kvstore, + param_arrays=executor_manager.param_arrays, + arg_params=arg_params, + param_names=executor_manager.param_names, + update_on_kvstore=update_on_kvstore) - batch_size = train_data.batch_size - - output_shapes = [tuple([batch_size]+list(x.shape[1:])) for x in train_execs[0].outputs] - cpu_output_arrays = [nd.zeros(s) for s in output_shapes] + if update_on_kvstore: + kvstore.set_optimizer(optimizer) # Now start training for epoch in range(begin_epoch, end_epoch): @@ -288,48 +214,31 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, while True: do_reset = True for data_batch in train_data: - _load_data(data_batch, data_arrays) - _load_label(data_batch, label_arrays) + + executor_manager.load_data_batch(data_batch) if monitor is not None: monitor.tic() - # forward backward pass - for texec, islice in zip(train_execs, slices): - texec.forward(is_train=True) - for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs): - dev_out.copyto(cpu_out[islice]) - #texec.outputs[0].copyto(out_cpu_array[islice]) - for texec in train_execs: - texec.backward() - - # update the parameters - for index, pair in enumerate(zip(param_arrays, grad_arrays)): - arg_list, grad_list = pair - if grad_list[0] is None: - continue - # Gradient synchronization - if kvstore: - # push gradient, priority is negative index - kvstore.push(index, grad_list, priority=-index) - if update_on_kvstore: - # pull back the weights - kvstore.pull(index, arg_list, priority=-index) - else: - # pull back the sum gradients, to the same locations. 
- kvstore.pull(index, grad_list, priority=-index) - if not update_on_kvstore: - for k, p in enumerate(zip(arg_list, grad_list)): - # faked an index here, to make optimizer create diff - # state for the same index but on diff devs, TODO(mli) - # use a better solution latter - w, g = p - updater(index*num_device+k, g, w) + + executor_manager.forward(is_train=True) + executor_manager.backward() + + if update_on_kvstore: + _update_params_on_kvstore(executor_manager.param_arrays, + executor_manager.grad_arrays, + kvstore) + else: + _update_params(executor_manager.param_arrays, + executor_manager.grad_arrays, + updater=updater, + num_device=len(ctx), + kvstore=kvstore) if monitor is not None: monitor.toc_print() # evaluate at end, so out_cpu_array can lazy copy - eval_metric.update(data_batch.label, cpu_output_arrays) + eval_metric.update(data_batch.label, executor_manager.cpu_output_arrays) nbatch += 1 # batch callback (for print purpose) @@ -364,26 +273,15 @@ def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names, eval_metric.reset() eval_data.reset() for eval_batch in eval_data: - _load_data(eval_batch, data_arrays) - _load_label(eval_batch, label_arrays) - - # forward pass - for texec, islice in zip(train_execs, slices): - texec.forward(is_train=False) - for cpu_out, dev_out in zip(cpu_output_arrays, texec.outputs): - dev_out.copyto(cpu_out[islice]) - eval_metric.update(eval_batch.label, cpu_output_arrays) + executor_manager.load_data_batch(eval_batch) + executor_manager.forward(is_train=False) + eval_metric.update(eval_batch.label, executor_manager.cpu_output_arrays) + name, value = eval_metric.get() logger.info('Epoch[%d] Validation-%s=%f', epoch, name, value) if epoch_end_callback or epoch + 1 == end_epoch: - # copy data back to cpu - for name, block in zip(param_names, param_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(arg_params[name]) - for name, block in zip(aux_names, aux_arrays): - weight = sum(w.copyto(cpu()) for w in block) / len(block) - weight.copyto(aux_params[name]) + executor_manager.copy_to(arg_params, aux_params) if epoch_end_callback != None: if isinstance(epoch_end_callback, list): From fee8175f0945233373c7866e6c242311c3859f6c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 16 Dec 2015 08:01:00 +0800 Subject: [PATCH 03/15] Update build.md Add introduction for building with Intel MKL --- doc/build.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/build.md b/doc/build.md index 95be4372ea65..8df3d9d648bf 100644 --- a/doc/build.md +++ b/doc/build.md @@ -126,6 +126,9 @@ various distributed filesystem such as HDFS/Amazon S3/... - First copy [make/config.mk](../make/config.mk) to the project root, on which any local modification will be ignored by git, then modify the according flags. +#### Building with Intel MKL Support +First, `source /path/to/intel/bin/compilervars.sh` to automatically set environment variables. Then, edit [make/config.mk](../make/config.mk), let `USE_BLAS = mkl`. `USE_INTEL_PATH = NONE` is usually not necessary to be modified. 
+ ## Python Package Installation From 087e1980900f2f144f5dc06467a8957cf43c842a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 16 Dec 2015 09:49:13 +0800 Subject: [PATCH 04/15] Update CONTRIBUTORS.md --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 38d06e5e7e07..31f1356cc0a4 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -89,3 +89,4 @@ List of Contributors * [Ye Zhou](https://github.com/zhouye) * [Zhang Chen](https://github.com/zhangchen-qinyinghua) * [Xianliang Wang](https://github.com/wangxianliang) +* [Junru Shao](https://github.com/yzgysjr) From 4746a3ff9717002f6bfafbbe74a577eca5b85c1e Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Wed, 16 Dec 2015 13:50:02 -0700 Subject: [PATCH 05/15] add inception-v3 symbol --- .../symbol_inception-v3.py | 174 ++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 example/image-classification/symbol_inception-v3.py diff --git a/example/image-classification/symbol_inception-v3.py b/example/image-classification/symbol_inception-v3.py new file mode 100644 index 000000000000..99075ea61c54 --- /dev/null +++ b/example/image-classification/symbol_inception-v3.py @@ -0,0 +1,174 @@ +""" + +Inception V3, suitable for images with around 299 x 299 + +Reference: + +Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015). + +""" + +import find_mxnet +import mxnet as mx + + +def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''): + conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, no_bias=True, name='%s%s_conv2d' %(name, suffix)) + bn = mx.sym.BatchNorm(data=conv, name='%s%s_batchnorm' %(name, suffix), fix_gamma=True) + act = mx.sym.Activation(data=bn, act_type='relu', name='%s%s_relu' %(name, suffix)) + return act + + +def Inception7A(data, + num_1x1, + num_3x3_red, num_3x3_1, num_3x3_2, + num_5x5_red, num_5x5, + pool, proj, + name): + tower_1x1 = Conv(data, num_1x1, name=('%s_conv' % name)) + tower_5x5 = Conv(data, num_5x5_red, name=('%s_tower' % name), suffix='_conv') + tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_3x3 = Conv(data, num_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_2') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(pooling, proj, name=('%s_tower_2' % name), suffix='_conv') + concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +# First Downsample +def Inception7B(data, + num_3x3, + num_d3x3_red, num_d3x3_1, num_d3x3_2, + pool, + name): + tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name)) + tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1') + tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2') + pooling = mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), 
pad=(0,0), pool_type="max", name=('max_pool_%s_pool' % name)) + concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7C(data, + num_1x1, + num_d7_red, num_d7_1, num_d7_2, + num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d7 = Conv(data=data, num_filter=num_d7_red, name=('%s_tower' % name), suffix='_conv') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower' % name), suffix='_conv_1') + tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower' % name), suffix='_conv_2') + tower_q7 = Conv(data=data, num_filter=num_q7_red, name=('%s_tower_1' % name), suffix='_conv') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_1') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_2') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_3') + tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_4') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7D(data, + num_3x3_red, num_3x3, + num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3, + pool, + name): + tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name=('%s_tower' % name), suffix='_conv') + tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0,0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red, name=('%s_tower_1' % name), suffix='_conv') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3), name=('%s_tower_1' % name), suffix='_conv_1') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0), name=('%s_tower_1' % name), suffix='_conv_2') + tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2), name=('%s_tower_1' % name), suffix='_conv_3') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + # concat + concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling], name='ch_concat_%s_chconcat' % name) + return concat + +def Inception7E(data, + num_1x1, + num_d3_red, num_d3_1, num_d3_2, + num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2, + pool, proj, + name): + tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name)) + tower_d3 = Conv(data=data, num_filter=num_d3_red, name=('%s_tower' % name), suffix='_conv') + tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower' % name), suffix='_mixed_conv') + tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower' % name), suffix='_mixed_conv_1') + tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red, name=('%s_tower_1' % name), suffix='_conv') + 
tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1), name=('%s_tower_1' % name), suffix='_conv_1') + tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1), name=('%s_tower_1' % name), suffix='_mixed_conv') + tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0), name=('%s_tower_1' % name), suffix='_mixed_conv_1') + pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name))) + cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1), name=('%s_tower_2' % name), suffix='_conv') + # concat + concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b, tower_3x3_d3_a, tower_3x3_d3_b, cproj], name='ch_concat_%s_chconcat' % name) + return concat + +# In[49]: + +def get_symbol(num_classes=1000): + data = mx.symbol.Variable(name="data") + # stage 1 + conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv") + conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1") + conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2") + pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool") + # stage 2 + conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3") + conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4") + pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1") + # stage 3 + in3a = Inception7A(pool1, 64, + 64, 96, 96, + 48, 64, + "avg", 32, "mixed") + in3b = Inception7A(in3a, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_1") + in3c = Inception7A(in3b, 64, + 64, 96, 96, + 48, 64, + "avg", 64, "mixed_2") + in3d = Inception7B(in3c, 384, + 64, 96, 96, + "max", "mixed_3") + # stage 4 + in4a = Inception7C(in3d, 192, + 128, 128, 192, + 128, 128, 128, 128, 192, + "avg", 192, "mixed_4") + in4b = Inception7C(in4a, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_5") + in4c = Inception7C(in4b, 192, + 160, 160, 192, + 160, 160, 160, 160, 192, + "avg", 192, "mixed_6") + in4d = Inception7C(in4c, 192, + 192, 192, 192, + 192, 192, 192, 192, 192, + "avg", 192, "mixed_7") + in4e = Inception7D(in4d, 192, 320, + 192, 192, 192, 192, + "max", "mixed_8") + # stage 5 + in5a = Inception7E(in4e, 320, + 384, 384, 384, + 448, 384, 384, 384, + "avg", 192, "mixed_9") + in5b = Inception7E(in5a, 320, + 384, 384, 384, + 448, 384, 384, 384, + "max", 192, "mixed_10") + # pool + pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool") + flatten = mx.sym.Flatten(data=pool, name="flatten") + fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1') + softmax = mx.symbol.SoftmaxOutput(data=fc1, name='softmax') + return softmax + From ff4f0194a3240f97e6836b7ed5ccc742cb7e7af4 Mon Sep 17 00:00:00 2001 From: qiaohaijun Date: Thu, 17 Dec 2015 18:41:22 +0800 Subject: [PATCH 06/15] fix the chunk_size fix the chunk_size so that doesn't lose some data --- tools/make_list.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/make_list.py b/tools/make_list.py index 926902807a54..8a042172ca4f 100644 --- a/tools/make_list.py +++ b/tools/make_list.py @@ -34,7 +34,7 @@ def make_list(prefix_out, root, recursive, exts, num_chunks, train_ratio): image_list = list_image(root, recursive, exts) random.shuffle(image_list) N = len(image_list) - chunk_size = N/num_chunks + chunk_size = (N+num_chunks-1)/num_chunks for i in xrange(num_chunks): chunk = 
image_list[i*chunk_size:(i+1)*chunk_size] if num_chunks > 1: @@ -70,4 +70,4 @@ def main(): args.exts, args.chunks, args.train_ratio) if __name__ == '__main__': - main() \ No newline at end of file + main() From 3f28d18ea31d2b5097b34e11c7924f4e34ecf294 Mon Sep 17 00:00:00 2001 From: skylook Date: Thu, 10 Dec 2015 17:52:52 +0800 Subject: [PATCH 07/15] Add image classification predict example for C++ --- example/cpp/image-classification/Makefile | 31 +++ example/cpp/image-classification/README.md | 64 +++++ .../image-classification-predict.cc | 237 ++++++++++++++++++ 3 files changed, 332 insertions(+) create mode 100644 example/cpp/image-classification/Makefile create mode 100644 example/cpp/image-classification/README.md create mode 100644 example/cpp/image-classification/image-classification-predict.cc diff --git a/example/cpp/image-classification/Makefile b/example/cpp/image-classification/Makefile new file mode 100644 index 000000000000..c1faea7e114b --- /dev/null +++ b/example/cpp/image-classification/Makefile @@ -0,0 +1,31 @@ +# Special thanks to https://github.com/pertusa for the Makefile +CFLAGS=-std=c++11 -Wno-unknown-pragmas -Wall + +# Added for openblas +# export OPENBLAS_ROOT=/usr/local/opt/openblas + +# CFLAGS+= -I${OPENBLAS_ROOT}/include +# LDFLAGS=-L${OPENBLAS_ROOT}/lib -lopenblas + +# Added for opencv +CFLAGS+= `pkg-config --cflags opencv` +LDFLAGS+=`pkg-config --libs opencv` + +# Added for mxnet +export MXNET_ROOT=`pwd`/../../../../mxnet + +CFLAGS+= -I$(MXNET_ROOT)/include +LDFLAGS+=$(MXNET_ROOT)/lib/libmxnet.so + +image-classification-predict: image-classification-predict.o + g++ -O3 -o image-classification-predict image-classification-predict.o $(LDFLAGS) + +image-classification-predict.o: image-classification-predict.cc + g++ -O3 -c image-classification-predict.cc ${CFLAGS} + +clean: + rm image-classification-predict + rm -f *.d *.o + +lint: + python ../../../dmlc-core/scripts/lint.py mxnet "cpp" ./ diff --git a/example/cpp/image-classification/README.md b/example/cpp/image-classification/README.md new file mode 100644 index 000000000000..71723dd30309 --- /dev/null +++ b/example/cpp/image-classification/README.md @@ -0,0 +1,64 @@ +# Image Classification Example of C++ +This is a simple predictor which shows how to use c api for image classfication. + +It uses opencv for image reading + +# How to Use + +## Build +1. Edit image-classification-predict.cc file, change the following lines to your model paths: + ```bash + // Models path for your model, you have to modify it + BufferFile json_data("model/Inception/Inception_BN-symbol.json"); + BufferFile param_data("model/Inception/Inception_BN-0039.params"); + ``` + +2. Edit synset file path if you have it: + ```bash + // Synset path for your model, you have to modify it + std::vector synset = LoadSynset("model/Inception/synset.txt"); + ``` + +3. You may also want to change the image size and channels: + ```bash + // Image size and channels + int width = 224; + int height = 224; + int channels = 3; + ``` + +4. Simply just use our Makefile to build: + ```bash + make + ``` + +## Usage +* Run: + ```bash + ./image-classification-predict apple.jpg + ``` +The only parameter is the path of the test image. + +## Tips +* The model used in the sample can be downloaded here: +http://pan.baidu.com/s/1sjXKrqX + +* If you donot run it in the mxnet root path, maybe you will need to copy lib folder here. 
+ +# Author +* **Xiao Liu** + +* E-mail: liuxiao@foxmail.com + +* Homepage: [www.liuxiao.org](http://www.liuxiao.org/) + +# Thanks +* pertusa (for Makefile and image reading check) + +* caprice-j (for reading function) + +* sofiawu (for sample model) + +* piiswrong and tqchen (for useful coding suggestions) + + diff --git a/example/cpp/image-classification/image-classification-predict.cc b/example/cpp/image-classification/image-classification-predict.cc new file mode 100644 index 000000000000..246b17b86428 --- /dev/null +++ b/example/cpp/image-classification/image-classification-predict.cc @@ -0,0 +1,237 @@ +/*! + * Copyright (c) 2015 by Xiao Liu, pertusa, caprice-j + * \file image_classification-predict.cpp + * \brief C++ predict example of mxnet + */ + +// +// File: image-classification-predict.cpp +// This is a simple predictor which shows +// how to use c api for image classfication +// It uses opencv for image reading +// Created by liuxiao on 12/9/15. +// Thanks to : pertusa, caprice-j, sofiawu, tqchen, piiswrong +// Home Page: www.liuxiao.org +// E-mail: liuxiao@foxmail.com +// + +#include + +// Path for c_predict_api +#include + +#include +#include +#include + +#include +#include +#include +#include + +// Read file to buffer +class BufferFile { + public : + std::string file_path_; + int length_; + char* buffer_; + + explicit BufferFile(std::string file_path) + :file_path_(file_path) { + + std::ifstream ifs(file_path.c_str(), std::ios::in | std::ios::binary); + if (!ifs) { + std::cerr << "Can't open the file. Please check " << file_path << ". \n"; + assert(false); + } + + ifs.seekg(0, std::ios::end); + length_ = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + std::cout << file_path.c_str() << " ... "<< length_ << " bytes\n"; + + buffer_ = new char[sizeof(char) * length_]; + ifs.read(buffer_, length_); + ifs.close(); + } + + int GetLength() { + return length_; + } + char* GetBuffer() { + return buffer_; + } + + ~BufferFile() { + delete[] buffer_; + buffer_ = NULL; + } +}; + +void GetMeanFile(const std::string image_file, mx_float* image_data, + const int channels, const cv::Size resize_size) { + // Read all kinds of file into a BGR color 3 channels image + cv::Mat im_ori = cv::imread(image_file, 1); + + if (im_ori.empty()) { + std::cerr << "Can't open the image. Please check " << image_file << ". 
\n"; + assert(false); + } + + cv::Mat im; + + resize(im_ori, im, resize_size); + + // Better to be read from a mean.nb file + float mean = 117.0; + + int size = im.rows * im.cols * 3; + + mx_float* ptr_image_r = image_data; + mx_float* ptr_image_g = image_data + size / 3; + mx_float* ptr_image_b = image_data + size / 3 * 2; + + for (int i = 0; i < im.rows; i++) { + uchar* data = im.ptr(i); + + for (int j = 0; j < im.cols; j++) { + mx_float b = static_cast(*data++) - mean; + mx_float g = static_cast(*data++) - mean; + mx_float r = static_cast(*data++) - mean; + + *ptr_image_r++ = r; + *ptr_image_g++ = g; + *ptr_image_b++ = b; + } + } +} + +// LoadSynsets +// Code from : https://github.com/pertusa/mxnet_predict_cc/blob/master/mxnet_predict.cc +std::vector LoadSynset(const char *filename) { + std::ifstream fi(filename); + + if ( !fi.is_open() ) { + std::cerr << "Error opening file " << filename << std::endl; + assert(false); + } + + std::vector output; + + std::string synset, lemma; + while ( fi >> synset ) { + getline(fi, lemma); + output.push_back(lemma); + } + + fi.close(); + + return output; +} + +void PrintOutputResult(const std::vector& data, const std::vector& synset) { + if (data.size() != synset.size()) { + std::cerr << "Result data and synset size does not match!" << std::endl; + } + + float best_accuracy = 0.0; + int best_idx = 0; + + for ( int i = 0; i < static_cast(data.size()); i++ ) { + printf("Accuracy[%d] = %.8f\n", i, data[i]); + + if ( data[i] > best_accuracy ) { + best_accuracy = data[i]; + best_idx = i; + } + } + + printf("Best Result: [%s] id = %d, accuracy = %.8f\n", + synset[best_idx].c_str(), best_idx, best_accuracy); +} + +int main(int argc, char* argv[]) { + if (argc < 2) { + std::cout << "No test image here." << std::endl + << "Usage: ./image-classification-predict apple.jpg" << std::endl; + return 0; + } + + std::string test_file; + test_file = std::string(argv[1]); + + // Models path for your model, you have to modify it + BufferFile json_data("model/Inception/Inception_BN-symbol.json"); + BufferFile param_data("model/Inception/Inception_BN-0039.params"); + + // Parameters + int dev_type = 1; // 1: cpu, 2: gpu + int dev_id = 0; // arbitrary. 
+ mx_uint num_input_nodes = 1; // 1 for feedforward + const char* input_key[1] = {"data"}; + const char** input_keys = input_key; + + // Image size and channels + int width = 224; + int height = 224; + int channels = 3; + + const mx_uint input_shape_indptr[2] = { 0, 4 }; + // ( trained_width, trained_height, channel, num) + const mx_uint input_shape_data[4] = { 1, + static_cast(channels), + static_cast(width), + static_cast(height) }; + PredictorHandle out = 0; // alias for void * + + //-- Create Predictor + MXPredCreate((const char*)json_data.GetBuffer(), + (const char*)param_data.GetBuffer(), + static_cast(param_data.GetLength()), + dev_type, + dev_id, + num_input_nodes, + input_keys, + input_shape_indptr, + input_shape_data, + &out); + + // Just a big enough memory 1000x1000x3 + int image_size = width * height * channels; + std::vector image_data = std::vector(image_size); + + //-- Read Mean Data + GetMeanFile(test_file, image_data.data(), channels, cv::Size(width, height)); + + //-- Set Input Image + MXPredSetInput(out, "data", image_data.data(), image_size); + + //-- Do Predict Forward + MXPredForward(out); + + mx_uint output_index = 0; + + mx_uint *shape = 0; + mx_uint shape_len; + + //-- Get Output Result + MXPredGetOutputShape(out, output_index, &shape, &shape_len); + + size_t size = 1; + for (mx_uint i = 0; i < shape_len; ++i) size *= shape[i]; + + std::vector data(size); + + MXPredGetOutput(out, output_index, &(data[0]), size); + + // Release Predictor + MXPredFree(out); + + // Synset path for your model, you have to modify it + std::vector synset = LoadSynset("model/Inception/synset.txt"); + + //-- Print Output Data + PrintOutputResult(data, synset); + + return 0; +} From 7bb03acd3917664fff90a4797506f32ef12a61ab Mon Sep 17 00:00:00 2001 From: skylook Date: Thu, 17 Dec 2015 20:18:28 +0800 Subject: [PATCH 08/15] Update CONTRIBUTORS.md --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 31f1356cc0a4..2868ecef1480 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -90,3 +90,4 @@ List of Contributors * [Zhang Chen](https://github.com/zhangchen-qinyinghua) * [Xianliang Wang](https://github.com/wangxianliang) * [Junru Shao](https://github.com/yzgysjr) +* [Xiao Liu](https://github.com/skylook) From 30bc343ea8e60d2dcb6fe66aefba756b7c0abd25 Mon Sep 17 00:00:00 2001 From: Mu Li Date: Thu, 17 Dec 2015 23:03:19 -0500 Subject: [PATCH 09/15] Update README.md --- tools/caffe_converter/README.md | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/tools/caffe_converter/README.md b/tools/caffe_converter/README.md index b7581345a1d2..6087796d81f9 100644 --- a/tools/caffe_converter/README.md +++ b/tools/caffe_converter/README.md @@ -1,30 +1,24 @@ # Convert Caffe Model to Mxnet Format -This tool converts a caffe model into mxnet's format. +### Build -## Build -If the Caffe python package is installed then no other step is required for -using. Otherwise, it requires Google protobuf to compile Caffe's model -format. One can either install protobuf by using package manager or build from -the source. For the latter, one can set `USE_DIST_KVSTORE = 1` when compiling -mxnet, namely +Either [Caffe's python package](http://caffe.berkeleyvision.org/installation.html) or [Google protobuf](https://developers.google.com/protocol-buffers/?hl=en) is required. The latter is often much easier to install: -``` -make -C ../.. USE_DIST_KVSTORE = 1 -``` +1. We first install the protobuf compiler. 
If you compiled mxnet with `USE_DIST_KVSTORE = 1` then it is already built. Otherwise, install `protobuf-compiler` by your favorate package manager, e.g. `sudo apt-get install protobuf-compiler` for ubuntu and `sudo yum install protobuf-compiler` for redhat/fedora. -Once `protobuf` is available, then run `make` in the current directory. +2. Then install the protobuf's python binding. For example `sudo pip install protobuf` + +Now we can build the tool by running `make` in the current directory. ## How to use Run ```python convert_model.py caffe_prototxt caffe_model save_model_name``` to convert the models. Run with ```-h``` for more details of parameters. - Or use `./run.sh model_name` to download and convert a model. Sample usage: `./run.sh vgg19` ## Note -* We have verified the results of VGG_16 model and BVLC_googlenet results from Caffe model zoo. +* We have verified the results of VGG_16/VGG_19 model and BVLC_googlenet results from Caffe model zoo. * The tool only supports single input and single output network. * The tool can only work with the L2LayerParameter in Caffe. From 60f58a54ddaa6fc60c524e0176d31fbce6feb1fa Mon Sep 17 00:00:00 2001 From: Mu Li Date: Thu, 17 Dec 2015 23:05:44 -0500 Subject: [PATCH 10/15] Update README.md --- tools/caffe_converter/README.md | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tools/caffe_converter/README.md b/tools/caffe_converter/README.md index 6087796d81f9..2e6eca1ea40c 100644 --- a/tools/caffe_converter/README.md +++ b/tools/caffe_converter/README.md @@ -4,20 +4,17 @@ Either [Caffe's python package](http://caffe.berkeleyvision.org/installation.html) or [Google protobuf](https://developers.google.com/protocol-buffers/?hl=en) is required. The latter is often much easier to install: -1. We first install the protobuf compiler. If you compiled mxnet with `USE_DIST_KVSTORE = 1` then it is already built. Otherwise, install `protobuf-compiler` by your favorate package manager, e.g. `sudo apt-get install protobuf-compiler` for ubuntu and `sudo yum install protobuf-compiler` for redhat/fedora. +1. We first install the protobuf compiler. If you compiled mxnet with `USE_DIST_KVSTORE = 1` then it is already built. Otherwise, install `protobuf-compiler` by your favor package manager, e.g. `sudo apt-get install protobuf-compiler` for ubuntu and `sudo yum install protobuf-compiler` for redhat/fedora. 2. Then install the protobuf's python binding. For example `sudo pip install protobuf` Now we can build the tool by running `make` in the current directory. -## How to use +### How to use -Run ```python convert_model.py caffe_prototxt caffe_model save_model_name``` to convert the models. Run with ```-h``` for more details of parameters. +Use `./run.sh model_name` to download and convert a model. E.g. `./run.sh vgg19` -Or use `./run.sh model_name` to download and convert a model. Sample usage: -`./run.sh vgg19` - -## Note +### Note * We have verified the results of VGG_16/VGG_19 model and BVLC_googlenet results from Caffe model zoo. * The tool only supports single input and single output network. 
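Editor's note on the converter README above: a minimal sketch of sanity-checking a converted model from Python. The `vgg19` prefix and epoch `0` are assumptions about the file names the converter wrote (e.g. `vgg19-symbol.json` / `vgg19-0000.params`); substitute whatever `run.sh`/`convert_model.py` actually produced.

```python
import mxnet as mx

# Assumed output of the converter: vgg19-symbol.json and vgg19-0000.params
# in the current directory; adjust prefix/epoch to your actual files.
model = mx.model.FeedForward.load('vgg19', 0)

# Inspect the converted network and its parameters.
print(model.symbol.list_arguments()[:5])        # first few argument names
print(len(model.arg_params), 'parameter arrays converted')
```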
From 8d9ee471e1e74c1c355a46ac9bd6c3977594b216 Mon Sep 17 00:00:00 2001 From: xjlc Date: Fri, 18 Dec 2015 15:37:07 +0100 Subject: [PATCH 11/15] Update build.md --- doc/build.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/build.md b/doc/build.md index 8df3d9d648bf..0c5f241c0cec 100644 --- a/doc/build.md +++ b/doc/build.md @@ -76,7 +76,7 @@ Then build mxnet ```bash git clone --recursive https://github.com/dmlc/mxnet -cd mxnet; cp make/osx.mk .;make -j4 +cd mxnet; cp make/osx.mk ./config.mk; make -j4 ``` Troubleshooting: From 6194cbf3082320e47681e34da5f4d7ee2f9280e3 Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Sun, 13 Dec 2015 15:50:02 -0800 Subject: [PATCH 12/15] Softmax Activation Layer --- mshadow | 2 +- src/operator/activation-inl.h | 9 +- src/operator/activation.cc | 4 +- src/operator/cudnn_activation-inl.h | 92 +++++++++---- src/operator/prod_sum-inl.h | 179 +++++++++++++++++++++++++ src/operator/prod_sum.cc | 30 +++++ src/operator/prod_sum.cu | 17 +++ tests/python/unittest/test_operator.py | 24 ++++ 8 files changed, 327 insertions(+), 30 deletions(-) create mode 100644 src/operator/prod_sum-inl.h create mode 100644 src/operator/prod_sum.cc create mode 100644 src/operator/prod_sum.cu diff --git a/mshadow b/mshadow index 3e28fdf07110..629d09ff9323 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit 3e28fdf07110911535b264f2283de756cd9ae131 +Subproject commit 629d09ff93232f73dbd58d71eba55b24830cfc1e diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index f879adbe42f9..1c22522665ef 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -24,7 +24,7 @@ namespace op { namespace activation { enum ActivationOpInputs {kData}; enum ActivationOpOutputs {kOut}; -enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU}; +enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftmax}; } // activation struct ActivationParam : public dmlc::Parameter { @@ -36,6 +36,7 @@ struct ActivationParam : public dmlc::Parameter { .add_enum("sigmoid", activation::kSigmoid) .add_enum("tanh", activation::kTanh) .add_enum("softrelu", activation::kSoftReLU) + .add_enum("softmax", activation::kSoftmax) .describe("Activation function to be applied."); } }; @@ -139,7 +140,11 @@ class ActivationProp : public OperatorProperty { const std::vector &in_data, const std::vector &out_data) const override { #if MXNET_USE_CUDNN == 1 - return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]}; + if (param_.act_type == activation::kSoftmax) { + return {out_grad[activation::kOut], out_data[activation::kOut]}; + } else { + return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]}; + } #else return {out_grad[activation::kOut], out_data[activation::kOut]}; #endif // MXNET_USE_CUDNN diff --git a/src/operator/activation.cc b/src/operator/activation.cc index 8cc904a5de91..b8da8c81c49a 100644 --- a/src/operator/activation.cc +++ b/src/operator/activation.cc @@ -34,7 +34,9 @@ Operator *ActivationProp::CreateOperator(Context ctx) const { DMLC_REGISTER_PARAMETER(ActivationParam); MXNET_REGISTER_OP_PROPERTY(Activation, ActivationProp) -.describe("Apply activation function to input.") +.describe("Apply activation function to input." 
+ "Softmax Activation is only available with CUDNN on GPU" + "and will be computed at each location across channel if input is 4D.") .add_argument("data", "Symbol", "Input data to activation function.") .add_arguments(ActivationParam::__FIELDS__()); diff --git a/src/operator/cudnn_activation-inl.h b/src/operator/cudnn_activation-inl.h index 7e6acea7c952..dd67786f0f2e 100644 --- a/src/operator/cudnn_activation-inl.h +++ b/src/operator/cudnn_activation-inl.h @@ -29,6 +29,8 @@ class CuDNNActivationOp : public Operator { case activation::kTanh: mode_ = CUDNN_ACTIVATION_TANH; break; + case activation::kSoftmax: + break; default: LOG(FATAL) << "Not implmented"; break; @@ -51,14 +53,17 @@ class CuDNNActivationOp : public Operator { Stream *s = ctx.get_stream(); Tensor data; Tensor out; + cudnnSoftmaxMode_t softmax_mode; if (in_data[activation::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0], in_data[activation::kData].shape_[1], 1, 1); data = in_data[activation::kData].get_with_shape(dshape, s); out = out_data[activation::kOut].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { data = in_data[activation::kData].get(s); out = out_data[activation::kOut].get(s); + softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } float alpha = 1.0f; float beta = 0.0f; @@ -74,14 +79,26 @@ class CuDNNActivationOp : public Operator { data.shape_[2], data.shape_[3]), CUDNN_STATUS_SUCCESS); } - CHECK_EQ(cudnnActivationForward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + if (param_.act_type == activation::kSoftmax) { + CHECK_EQ(cudnnSoftmaxForward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } else { + CHECK_EQ(cudnnActivationForward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } } virtual void Backward(const OpContext &ctx, @@ -94,7 +111,9 @@ class CuDNNActivationOp : public Operator { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); - CHECK_EQ(in_data.size(), 1); + if (param_.act_type != activation::kSoftmax) { + CHECK_EQ(in_data.size(), 1); + } CHECK_EQ(out_data.size(), 1); CHECK_EQ(req.size(), 1); CHECK_EQ(in_grad.size(), 1); @@ -105,32 +124,53 @@ class CuDNNActivationOp : public Operator { Tensor data; Tensor output_data; Tensor input_grad; - if (in_data[activation::kData].ndim() == 2) { - Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0], - in_data[activation::kData].shape_[1], 1, 1); - data = in_data[activation::kData].get_with_shape(dshape, s); + cudnnSoftmaxMode_t softmax_mode; + if (in_grad[activation::kData].ndim() == 2) { + Shape<4> dshape = Shape4(in_grad[activation::kData].shape_[0], + in_grad[activation::kData].shape_[1], 1, 1); + if (param_.act_type != activation::kSoftmax) { + data = in_data[activation::kData].get_with_shape(dshape, s); + } grad = out_grad[activation::kOut].get_with_shape(dshape, s); output_data = out_data[activation::kOut].get_with_shape(dshape, s); input_grad = in_grad[activation::kData].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { - data = in_data[activation::kData].get(s); + if (param_.act_type != activation::kSoftmax) { + data = in_data[activation::kData].get(s); + } output_data = out_data[activation::kOut].get(s); grad = out_grad[activation::kOut].get(s); 
input_grad = in_grad[activation::kData].get(s); + softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); + if (param_.act_type == activation::kSoftmax) { + CHECK_EQ(cudnnSoftmaxBackward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + &beta, + shape_desc_, + input_grad.dptr_), CUDNN_STATUS_SUCCESS); + } else { + CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_), CUDNN_STATUS_SUCCESS); + } } private: diff --git a/src/operator/prod_sum-inl.h b/src/operator/prod_sum-inl.h new file mode 100644 index 000000000000..79b8c68d7475 --- /dev/null +++ b/src/operator/prod_sum-inl.h @@ -0,0 +1,179 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file leaky_relu-inl.h + * \brief leaky relu family operator + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_PROD_SUM_INL_H_ +#define MXNET_OPERATOR_PROD_SUM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { + +namespace prodsum { +enum ProdSumOpInputs {kLhs, kRhs}; +enum ProdSumOpOutputs {kOut}; +} // namespace prodsum + +struct ProdSumParam : public dmlc::Parameter { + index_t dot_dim; + DMLC_DECLARE_PARAMETER(ProdSumParam) { + DMLC_DECLARE_FIELD(dot_dim) + .describe("The dimension along with to do dot product."); + } +}; + +template +class ProdSumOp : public Operator { + public: + explicit ProdSumOp(ProdSumParam param) { + param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + TShape lshape = in_data[prodsum::kLhs].shape_; + Shape<3> ishape = ShapeCheck(in_data); + Stream *s = ctx.get_stream(); + Tensor lhs = in_data[prodsum::kLhs] + .get_with_shape(ishape, s); + Tensor rhs = in_data[prodsum::kRhs] + .get_with_shape(ishape, s); + Tensor out = out_data[prodsum::kOut] + .get_with_shape(Shape2(ishape[0], ishape[2]), s); + Assign(out, req[prodsum::kOut], (reduce_with_axis(lhs*rhs))); + } + + virtual void Backward(const OpContext & ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + TShape lshape = in_data[prodsum::kLhs].shape_; + Shape<3> ishape = ShapeCheck(in_data); + Stream *s = ctx.get_stream(); + Tensor lhs = in_data[prodsum::kLhs] + .get_with_shape(ishape, s); + Tensor lhs_grad = in_grad[prodsum::kLhs] + .get_with_shape(ishape, s); + Tensor rhs = in_data[prodsum::kRhs] + .get_with_shape(ishape, s); + Tensor rhs_grad = in_grad[prodsum::kRhs] + .get_with_shape(ishape, s); + Tensor top = out_grad[prodsum::kOut] + .get_with_shape(Shape2(ishape[0], ishape[2]), s); + Assign(lhs_grad, req[prodsum::kLhs], (broadcast_with_axis<0>(top, ishape[1])*rhs)); + Assign(rhs_grad, req[prodsum::kRhs], (broadcast_with_axis<0>(top, 
ishape[1])*lhs)); + } + + private: + ProdSumParam param_; + + mshadow::Shape<3> ShapeCheck(const std::vector &in_data) { + index_t leading = 1, trailing = 1; + TShape lshape = in_data[prodsum::kLhs].shape_; + TShape rshape = in_data[prodsum::kRhs].shape_; + CHECK_EQ(lshape, rshape) << "Shape of two inputs must match"; + CHECK(lshape.ndim() > param_.dot_dim) + << "Inputs must have more dimensions than dot_dim"; + for (index_t i = 0; i < param_.dot_dim; ++i) { + leading *= lshape[i]; + } + for (index_t i = param_.dot_dim+1; i < lshape.ndim(); ++i) { + trailing *= lshape[i]; + } + return mshadow::Shape3(leading, lshape[param_.dot_dim], trailing); + } +}; // class ProdSumOp + +template +Operator* CreateOp(ProdSumParam type); + +#if DMLC_USE_CXX11 +class ProdSumProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + TShape lshape = in_shape->at(0); + TShape rshape = in_shape->at(1); + CHECK_EQ(lshape, rshape) << "Shape of two inputs must match"; + CHECK(lshape.ndim() > param_.dot_dim) + << "Inputs must have more dimensions than dot_dim"; + std::vector s; + for (index_t i = 0; i < lshape.ndim(); ++i) { + if (i != param_.dot_dim) { + s.push_back(lshape[i]); + } + } + TShape oshape(s.begin(), s.end()); + out_shape->clear(); + out_shape->push_back(oshape); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new ProdSumProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "ProdSum"; + } + + // decalre dependency and inplace optimization options + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {in_data[prodsum::kLhs], in_data[prodsum::kRhs], out_grad[prodsum::kOut]}; + } + + std::vector ListArguments() const override { + return {"lhs", "rhs"}; + } + + std::vector ListOutputs() const override { + return {"output"}; + } + + Operator* CreateOperator(Context ctx) const override; + + private: + ProdSumParam param_; +}; +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_PROD_SUM_INL_H_ + diff --git a/src/operator/prod_sum.cc b/src/operator/prod_sum.cc new file mode 100644 index 000000000000..33e50eafe7de --- /dev/null +++ b/src/operator/prod_sum.cc @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file prod_sum.cc + * \brief product sum op + * \author Junyuan Xie +*/ +#include "./prod_sum-inl.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(ProdSumParam param) { + return new ProdSumOp(param); +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator *ProdSumProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(ProdSumParam); + +MXNET_REGISTER_OP_PROPERTY(ProdSum, ProdSumProp) +.describe("Compute dot product along one dim of 2 tensors.") +.add_arguments(ProdSumParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/prod_sum.cu b/src/operator/prod_sum.cu new file mode 100644 index 000000000000..c345be586a31 --- /dev/null +++ b/src/operator/prod_sum.cu @@ -0,0 +1,17 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file prod_sum.cu + * \brief + * \author Junyuan Xie +*/ +#include "./prod_sum-inl.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(ProdSumParam param) { + return new ProdSumOp(param); +} +} // op +} // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 4579bdbdf4cb..bf9d1dde8fa4 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -621,7 +621,31 @@ def test_nearest_upsampling(): shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] check_nearest_upsampling_with_shape(shapes, scale, root_scale) +def check_prod_sum_with_shape(shape, dot_dim): + x = mx.sym.Variable('x') + X = mx.random.uniform(-1, 1, shape=shape, ctx=mx.cpu()) + dX = mx.nd.zeros(shape, ctx=mx.cpu()) + y = mx.sym.Variable('y') + Y = mx.random.uniform(-1, 1, shape=shape, ctx=mx.cpu()) + dY = mx.nd.zeros(shape, ctx=mx.cpu()) + z = mx.sym.ProdSum(lhs=x, rhs=y, dot_dim=dot_dim) + exe = z.bind(mx.cpu(), args={'x':X, 'y': Y}, args_grad={'x': dX, 'y': dY}) + exe.forward(is_train=True) + assert_allclose(exe.outputs[0].asnumpy(), np.sum(X.asnumpy()*Y.asnumpy(), axis=dot_dim), rtol=1e-4) + dZ = mx.nd.ones(exe.outputs[0].shape, ctx=mx.cpu()) + exe.backward(dZ) + assert_allclose(dX.asnumpy(), Y.asnumpy(), rtol=1e-4) + assert_allclose(dY.asnumpy(), X.asnumpy(), rtol=1e-4) + + +def test_prod_sum(): + check_prod_sum_with_shape((3,5,3), 0) + check_prod_sum_with_shape((3,5,3), 1) + check_prod_sum_with_shape((3,5,3), 2) + + if __name__ == '__main__': + test_prod_sum(); test_nearest_upsampling() test_binary_op_duplicate_input() test_elementwise_sum() From 5b9f40d74976b8700db68e7ac54a58b29dbaae08 Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Wed, 16 Dec 2015 17:46:35 -0800 Subject: [PATCH 13/15] various small improvements Makefile target fix for .cu files; RMSE Metric; EXTRA_OPERATOR compiler flag --- Makefile | 2 +- python/mxnet/metric.py | 12 ++++++++++++ python/mxnet/recordio.py | 4 ++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index e1db82c5e92d..b1b256b686d4 100644 --- a/Makefile +++ b/Makefile @@ -112,7 +112,7 @@ build/%.o: src/%.cc build/%_gpu.o: src/%.cu @mkdir -p $(@D) - $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M build/$*_gpu.o $< >build/$*_gpu.d + $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/$*_gpu.o $< >build/$*_gpu.d $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $< lib/libmxnet.a: $(ALL_DEP) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 72a4adebce54..8e3efe511c0c 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -71,6 +71,18 @@ def update(self, labels, preds): self.sum_metric += numpy.sum(numpy.abs(label.asnumpy() - pred.asnumpy())) self.num_inst += numpy.prod(label.shape) +class RMSE(EvalMetric): + """Calculate Root Mean Squred Error loss""" + def __init__(self): + super(RMSE, self).__init__('rmse') + + def update(self, labels, preds): + assert len(labels) == len(preds) + for label, pred in zip(labels, preds): + assert label.shape == pred.shape + self.sum_metric += numpy.sqrt(numpy.mean((label.asnumpy() - pred.asnumpy())**2)) + self.num_inst += 1 + class CustomMetric(EvalMetric): """Custom evaluation metric that takes a NDArray function. 
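The RMSE metric added above accumulates one root-mean-squared error per update() call and divides by the number of calls, so it reports the average of per-batch RMSE values rather than the RMSE over all samples pooled together. A minimal usage sketch, assuming the standard EvalMetric.get() accessor; the arrays are illustrative only:

    import numpy as np
    import mxnet as mx
    from mxnet.metric import RMSE

    metric = RMSE()
    label = mx.nd.array(np.array([[1.0, 2.0], [3.0, 4.0]]))
    pred = mx.nd.array(np.array([[1.5, 2.0], [3.0, 3.0]]))
    metric.update([label], [pred])   # one batch of labels/predictions
    print(metric.get())              # ('rmse', sqrt(mean((label - pred) ** 2)))
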
diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 5cbb272b1360..427b37a527e2 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -157,7 +157,7 @@ def unpack_img(s, iscolor=-1): img = cv2.imdecode(img, iscolor) return header, img -def pack_img(header, img, quality=80): +def pack_img(header, img, quality=80, format='.JPEG'): """pack an image into MXImageRecord Parameters @@ -175,6 +175,6 @@ def pack_img(header, img, quality=80): The packed string """ assert opencv_available - ret, buf = cv2.imencode('.JPEG', img, [cv2.IMWRITE_JPEG_QUALITY, quality]) + ret, buf = cv2.imencode(format, img, [cv2.IMWRITE_JPEG_QUALITY, quality]) assert ret return pack(header, buf.tostring()) From bf172a0bf3f04b055aff110466cf1b9e199184d4 Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Wed, 16 Dec 2015 23:34:01 -0800 Subject: [PATCH 14/15] support nd cudnn activation --- src/operator/cudnn_activation-inl.h | 34 ++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/operator/cudnn_activation-inl.h b/src/operator/cudnn_activation-inl.h index dd67786f0f2e..04451e6a721f 100644 --- a/src/operator/cudnn_activation-inl.h +++ b/src/operator/cudnn_activation-inl.h @@ -61,8 +61,19 @@ class CuDNNActivationOp : public Operator { out = out_data[activation::kOut].get_with_shape(dshape, s); softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { - data = in_data[activation::kData].get(s); - out = out_data[activation::kOut].get(s); + Shape<4> dshape; + index_t size_left = in_data[activation::kData].Size(); + for (int i = 0; i < 3; ++i) { + if (i < in_data[activation::kData].ndim()) { + dshape[i] = in_data[activation::kData].shape_[i]; + } else { + dshape[i] = 1; + } + size_left /= dshape[i]; + } + dshape[3] = size_left; + data = in_data[activation::kData].get_with_shape(dshape, s); + out = out_data[activation::kOut].get_with_shape(dshape, s); softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } float alpha = 1.0f; @@ -136,12 +147,23 @@ class CuDNNActivationOp : public Operator { input_grad = in_grad[activation::kData].get_with_shape(dshape, s); softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { + Shape<4> dshape; + index_t size_left = in_grad[activation::kData].Size(); + for (int i = 0; i < 3; ++i) { + if (i < in_grad[activation::kData].ndim()) { + dshape[i] = in_grad[activation::kData].shape_[i]; + } else { + dshape[i] = 1; + } + size_left /= dshape[i]; + } + dshape[3] = size_left; if (param_.act_type != activation::kSoftmax) { - data = in_data[activation::kData].get(s); + data = in_data[activation::kData].get_with_shape(dshape, s); } - output_data = out_data[activation::kOut].get(s); - grad = out_grad[activation::kOut].get(s); - input_grad = in_grad[activation::kData].get(s); + output_data = out_data[activation::kOut].get_with_shape(dshape, s); + grad = out_grad[activation::kOut].get_with_shape(dshape, s); + input_grad = in_grad[activation::kData].get_with_shape(dshape, s); softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); From ea259985d839f0a77c03703ab7271440bc9ddc2f Mon Sep 17 00:00:00 2001 From: Junyuan Xie Date: Wed, 16 Dec 2015 23:38:36 -0800 Subject: [PATCH 15/15] make ccsgd the default and add number of batch to predict --- Makefile | 28 ++- make/config.mk | 7 + make/osx.mk | 7 + python/mxnet/model.py | 13 +- python/mxnet/optimizer.py | 2 +- python/mxnet/recordio.py | 6 +- src/operator/activation-inl.h | 9 +- src/operator/cudnn_activation-inl.h | 86 +++------- 
src/operator/cudnn_softmax_activation-inl.h | 163 ++++++++++++++++++ src/operator/prod_sum-inl.h | 179 -------------------- src/operator/prod_sum.cc | 30 ---- src/operator/prod_sum.cu | 17 -- src/operator/softmax_activation-inl.h | 162 ++++++++++++++++++ src/operator/softmax_activation.cc | 38 +++++ src/operator/softmax_activation.cu | 27 +++ src/operator/softmax_output-inl.h | 4 +- tests/python/unittest/test_operator.py | 24 --- tools/caffe_converter/run.sh | 13 +- 18 files changed, 483 insertions(+), 332 deletions(-) create mode 100644 src/operator/cudnn_softmax_activation-inl.h delete mode 100644 src/operator/prod_sum-inl.h delete mode 100644 src/operator/prod_sum.cc delete mode 100644 src/operator/prod_sum.cu create mode 100644 src/operator/softmax_activation-inl.h create mode 100644 src/operator/softmax_activation.cc create mode 100644 src/operator/softmax_activation.cu diff --git a/Makefile b/Makefile index b1b256b686d4..b8f5408d3702 100644 --- a/Makefile +++ b/Makefile @@ -98,13 +98,27 @@ OBJ = $(patsubst src/%.cc, build/%.o, $(SRC)) CUSRC = $(wildcard src/*/*.cu) CUOBJ = $(patsubst src/%.cu, build/%_gpu.o, $(CUSRC)) +ifneq ($(EXTRA_OPERATORS), NONE) + EXTRA_SRC = $(wildcard $(EXTRA_OPERATORS)/*.cc $(EXTRA_OPERATORS)/*/*.cc) + EXTRA_OBJ = $(patsubst $(EXTRA_OPERATORS)/%.cc, $(EXTRA_OPERATORS)/build/%.o, $(EXTRA_SRC)) + EXTRA_CUSRC = $(wildcard $(EXTRA_OPERATORS)/*.cu $(EXTRA_OPERATORS)/*/*.cu) + EXTRA_CUOBJ = $(patsubst $(EXTRA_OPERATORS)/%.cu, $(EXTRA_OPERATORS)/build/%_gpu.o, $(EXTRA_CUSRC)) +else + EXTRA_SRC = + EXTRA_OBJ = + EXTRA_CUSRC = + EXTRA_CUOBJ = +endif + LIB_DEP += $(DMLC_CORE)/libdmlc.a -ALL_DEP = $(OBJ) $(LIB_DEP) +ALL_DEP = $(OBJ) $(EXTRA_OBJ) $(LIB_DEP) ifeq ($(USE_CUDA), 1) - ALL_DEP += $(CUOBJ) + ALL_DEP += $(CUOBJ) $(EXTRA_CUOBJ) LDFLAGS += -lnvrtc -lcuda endif + + build/%.o: src/%.cc @mkdir -p $(@D) $(CXX) -std=c++0x $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d @@ -115,6 +129,16 @@ build/%_gpu.o: src/%.cu $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -M -MT build/$*_gpu.o $< >build/$*_gpu.d $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $< +$(EXTRA_OPERATORS)/build/%.o: $(EXTRA_OPERATORS)/%.cc + @mkdir -p $(@D) + $(CXX) -std=c++0x $(CFLAGS) -Isrc/operator -MM -MT $(EXTRA_OPERATORS)/build/$*.o $< >$(EXTRA_OPERATORS)/build/$*.d + $(CXX) -std=c++0x -c $(CFLAGS) -Isrc/operator -c $< -o $@ + +$(EXTRA_OPERATORS)/build/%_gpu.o: $(EXTRA_OPERATORS)/%.cu + @mkdir -p $(@D) + $(NVCC) $(NVCCFLAGS) -Xcompiler "$(CFLAGS) -Isrc/operator" -M -MT $(EXTRA_OPERATORS)/build/$*_gpu.o $< >$(EXTRA_OPERATORS)/build/$*_gpu.d + $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS) -Isrc/operator" $< + lib/libmxnet.a: $(ALL_DEP) @mkdir -p $(@D) ar crv $@ $(filter %.o, $?) 
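The EXTRA_OPERATORS hook above compiles every .cc/.cu it finds under the configured directory (and one level of subdirectories) with -Isrc/operator on the include path, and archives the resulting objects into libmxnet, so project-specific operators can live outside src/operator. A minimal sketch of wiring it up, with a hypothetical ../my_ops directory:

    # in make/config.mk (or overridden on the make command line)
    EXTRA_OPERATORS = ../my_ops

    # expected layout: ../my_ops/my_op-inl.h, ../my_ops/my_op.cc, ../my_ops/my_op.cu
    # objects land in ../my_ops/build/ and are linked into lib/libmxnet.*
    make -j4
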
diff --git a/make/config.mk b/make/config.mk index a5940c5a1a11..6585e5299f5e 100644 --- a/make/config.mk +++ b/make/config.mk @@ -95,3 +95,10 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server # libcurl4-openssl-dev is required, it can be installed on Ubuntu by # sudo apt-get install -y libcurl4-openssl-dev USE_S3 = 0 + +#---------------------------- +# additional operators +#---------------------------- + +# path to folders containing projects specific operators that you don't want to put in src/operators +EXTRA_OPERATORS = diff --git a/make/osx.mk b/make/osx.mk index cf5d040732df..13a6389bba04 100644 --- a/make/osx.mk +++ b/make/osx.mk @@ -82,3 +82,10 @@ LIBJVM=$(JAVA_HOME)/jre/lib/amd64/server # libcurl4-openssl-dev is required, it can be installed on Ubuntu by # sudo apt-get install -y libcurl4-openssl-dev USE_S3 = 0 + +#---------------------------- +# additional operators +#---------------------------- + +# path to folders containing projects specific operators that you don't want to put in src/operators +EXTRA_OPERATORS = diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 70774321fdcc..7299a034f171 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -493,7 +493,7 @@ class FeedForward(BASE_ESTIMATOR): The additional keyword arguments passed to optimizer. """ def __init__(self, symbol, ctx=None, - num_epoch=None, epoch_size=None, optimizer='sgd', + num_epoch=None, epoch_size=None, optimizer='ccsgd', initializer=Uniform(0.01), numpy_batch_size=128, arg_params=None, aux_params=None, @@ -632,11 +632,13 @@ def _init_eval_iter(self, eval_data): 'NDArray/numpy.ndarray/list pair (i.e. tuple/list of length 2)') return eval_data - def predict(self, X): + def predict(self, X, num_batch=None): """Run the prediction, always only use one device. Parameters ---------- X : mxnet.DataIter + num_batch : int or None + the number of batch to run. Go though all batches if None Returns ------- y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs. 
@@ -652,7 +654,12 @@ def predict(self, X): data_arrays = [self._pred_exec.arg_dict[name] for name in data_names] output_list = [[] for _ in range(len(self._pred_exec.outputs))] + i = 0 for batch in X: + if num_batch is not None and i == num_batch: + break + i += 1 + _load_data(batch, data_arrays) self._pred_exec.forward(is_train=False) padded = batch.pad @@ -803,7 +810,7 @@ def load(prefix, epoch, ctx=None, **kwargs): @staticmethod def create(symbol, X, y=None, ctx=None, - num_epoch=None, epoch_size=None, optimizer='sgd', initializer=Uniform(0.01), + num_epoch=None, epoch_size=None, optimizer='ccsgd', initializer=Uniform(0.01), eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', logger=None, work_load_list=None, **kwargs): diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 589158e0b075..738e39752edd 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -72,7 +72,7 @@ def _init_cc_optimizer(name, param_keys, param_vals): handle to the optimizer """ creator = OptimizerCreator() - check_call(_LIB.MXOptimizerFindCreator(ctypes.c_char_p(name), + check_call(_LIB.MXOptimizerFindCreator(c_str(name), ctypes.byref(creator))) assert creator, "Cannot find c++ implementation of optimizer \ registered with name "+name diff --git a/python/mxnet/recordio.py b/python/mxnet/recordio.py index 427b37a527e2..0deb7bac6cf4 100644 --- a/python/mxnet/recordio.py +++ b/python/mxnet/recordio.py @@ -157,7 +157,7 @@ def unpack_img(s, iscolor=-1): img = cv2.imdecode(img, iscolor) return header, img -def pack_img(header, img, quality=80, format='.JPEG'): +def pack_img(header, img, quality=80, img_fmt='.JPEG'): """pack an image into MXImageRecord Parameters @@ -175,6 +175,6 @@ def pack_img(header, img, quality=80, format='.JPEG'): The packed string """ assert opencv_available - ret, buf = cv2.imencode(format, img, [cv2.IMWRITE_JPEG_QUALITY, quality]) - assert ret + ret, buf = cv2.imencode(img_fmt, img, [cv2.IMWRITE_JPEG_QUALITY, quality]) + assert ret, 'failed encoding image' return pack(header, buf.tostring()) diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 1c22522665ef..f879adbe42f9 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -24,7 +24,7 @@ namespace op { namespace activation { enum ActivationOpInputs {kData}; enum ActivationOpOutputs {kOut}; -enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftmax}; +enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU}; } // activation struct ActivationParam : public dmlc::Parameter { @@ -36,7 +36,6 @@ struct ActivationParam : public dmlc::Parameter { .add_enum("sigmoid", activation::kSigmoid) .add_enum("tanh", activation::kTanh) .add_enum("softrelu", activation::kSoftReLU) - .add_enum("softmax", activation::kSoftmax) .describe("Activation function to be applied."); } }; @@ -140,11 +139,7 @@ class ActivationProp : public OperatorProperty { const std::vector &in_data, const std::vector &out_data) const override { #if MXNET_USE_CUDNN == 1 - if (param_.act_type == activation::kSoftmax) { - return {out_grad[activation::kOut], out_data[activation::kOut]}; - } else { - return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]}; - } + return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]}; #else return {out_grad[activation::kOut], out_data[activation::kOut]}; #endif // MXNET_USE_CUDNN diff --git a/src/operator/cudnn_activation-inl.h 
b/src/operator/cudnn_activation-inl.h index 04451e6a721f..8d7edf3e8411 100644 --- a/src/operator/cudnn_activation-inl.h +++ b/src/operator/cudnn_activation-inl.h @@ -29,8 +29,6 @@ class CuDNNActivationOp : public Operator { case activation::kTanh: mode_ = CUDNN_ACTIVATION_TANH; break; - case activation::kSoftmax: - break; default: LOG(FATAL) << "Not implmented"; break; @@ -53,13 +51,11 @@ class CuDNNActivationOp : public Operator { Stream *s = ctx.get_stream(); Tensor data; Tensor out; - cudnnSoftmaxMode_t softmax_mode; if (in_data[activation::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0], in_data[activation::kData].shape_[1], 1, 1); data = in_data[activation::kData].get_with_shape(dshape, s); out = out_data[activation::kOut].get_with_shape(dshape, s); - softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { Shape<4> dshape; index_t size_left = in_data[activation::kData].Size(); @@ -74,7 +70,6 @@ class CuDNNActivationOp : public Operator { dshape[3] = size_left; data = in_data[activation::kData].get_with_shape(dshape, s); out = out_data[activation::kOut].get_with_shape(dshape, s); - softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } float alpha = 1.0f; float beta = 0.0f; @@ -90,26 +85,14 @@ class CuDNNActivationOp : public Operator { data.shape_[2], data.shape_[3]), CUDNN_STATUS_SUCCESS); } - if (param_.act_type == activation::kSoftmax) { - CHECK_EQ(cudnnSoftmaxForward(s->dnn_handle_, - CUDNN_SOFTMAX_ACCURATE, - softmax_mode, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); - } else { - CHECK_EQ(cudnnActivationForward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); - } + CHECK_EQ(cudnnActivationForward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); } virtual void Backward(const OpContext &ctx, @@ -122,9 +105,7 @@ class CuDNNActivationOp : public Operator { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); - if (param_.act_type != activation::kSoftmax) { - CHECK_EQ(in_data.size(), 1); - } + CHECK_EQ(in_data.size(), 1); CHECK_EQ(out_data.size(), 1); CHECK_EQ(req.size(), 1); CHECK_EQ(in_grad.size(), 1); @@ -135,17 +116,13 @@ class CuDNNActivationOp : public Operator { Tensor data; Tensor output_data; Tensor input_grad; - cudnnSoftmaxMode_t softmax_mode; if (in_grad[activation::kData].ndim() == 2) { Shape<4> dshape = Shape4(in_grad[activation::kData].shape_[0], in_grad[activation::kData].shape_[1], 1, 1); - if (param_.act_type != activation::kSoftmax) { - data = in_data[activation::kData].get_with_shape(dshape, s); - } + data = in_data[activation::kData].get_with_shape(dshape, s); grad = out_grad[activation::kOut].get_with_shape(dshape, s); output_data = out_data[activation::kOut].get_with_shape(dshape, s); input_grad = in_grad[activation::kData].get_with_shape(dshape, s); - softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; } else { Shape<4> dshape; index_t size_left = in_grad[activation::kData].Size(); @@ -158,41 +135,24 @@ class CuDNNActivationOp : public Operator { size_left /= dshape[i]; } dshape[3] = size_left; - if (param_.act_type != activation::kSoftmax) { - data = in_data[activation::kData].get_with_shape(dshape, s); - } + data = in_data[activation::kData].get_with_shape(dshape, s); output_data = out_data[activation::kOut].get_with_shape(dshape, s); grad = out_grad[activation::kOut].get_with_shape(dshape, s); input_grad = 
in_grad[activation::kData].get_with_shape(dshape, s); - softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - if (param_.act_type == activation::kSoftmax) { - CHECK_EQ(cudnnSoftmaxBackward(s->dnn_handle_, - CUDNN_SOFTMAX_ACCURATE, - softmax_mode, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); - } else { - CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); - } + CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_), CUDNN_STATUS_SUCCESS); } private: diff --git a/src/operator/cudnn_softmax_activation-inl.h b/src/operator/cudnn_softmax_activation-inl.h new file mode 100644 index 000000000000..9b904f504d43 --- /dev/null +++ b/src/operator/cudnn_softmax_activation-inl.h @@ -0,0 +1,163 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cudnn_activation-inl.h + * \brief + * \author Bing Xu +*/ + +#ifndef MXNET_OPERATOR_CUDNN_SOFTMAX_ACTIVATION_INL_H_ +#define MXNET_OPERATOR_CUDNN_SOFTMAX_ACTIVATION_INL_H_ +#include +#include +#include "./softmax_activation-inl.h" + +namespace mxnet { +namespace op { +class CuDNNSoftmaxActivationOp : public Operator { + public: + explicit CuDNNSoftmaxActivationOp(SoftmaxActivationParam param) { + this->param_ = param; + init_cudnn_ = false; + dtype_ = CUDNN_DATA_FLOAT; + } + + ~CuDNNSoftmaxActivationOp() { + CHECK_EQ(cudnnDestroyTensorDescriptor(shape_desc_), CUDNN_STATUS_SUCCESS); + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data; + Tensor out; + cudnnSoftmaxMode_t softmax_mode; + if (param_.type == softmax_activation::kInstance) { + CHECK_EQ(in_data[softmax_activation::kData].ndim(), 2) + << "Input need to have 2 dimensions when type=instance."; + Shape<4> dshape = Shape4(in_data[softmax_activation::kData].shape_[0], + in_data[softmax_activation::kData].shape_[1], 1, 1); + data = in_data[softmax_activation::kData].get_with_shape(dshape, s); + out = out_data[softmax_activation::kOut].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; + } else { + CHECK_GE(in_data[softmax_activation::kData].ndim(), 3) + << "Input need to have a least 3 dimensions when type=channel"; + Shape<4> dshape; + index_t size_left = in_data[softmax_activation::kData].Size(); + for (int i = 0; i < 3; ++i) { + if (i < in_data[softmax_activation::kData].ndim()) { + dshape[i] = in_data[softmax_activation::kData].shape_[i]; + } else { + dshape[i] = 1; + } + size_left /= dshape[i]; + } + dshape[3] = size_left; + data = in_data[softmax_activation::kData].get_with_shape(dshape, s); + out = out_data[softmax_activation::kOut].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; + } + float alpha = 1.0f; + float beta = 0.0f; + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + if (!init_cudnn_) { + init_cudnn_ = true; + CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), 
CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3]), CUDNN_STATUS_SUCCESS); + } + CHECK_EQ(cudnnSoftmaxForward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_), CUDNN_STATUS_SUCCESS); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req.size(), 1); + CHECK_EQ(in_grad.size(), 1); + float alpha = 1.0f; + float beta = 0.0f; + Stream *s = ctx.get_stream(); + Tensor grad; + Tensor data; + Tensor output_data; + Tensor input_grad; + cudnnSoftmaxMode_t softmax_mode; + if (param_.type == softmax_activation::kInstance) { + CHECK_EQ(in_grad[softmax_activation::kData].ndim(), 2) + << "Input need to have 2 dimensions when type=instance."; + Shape<4> dshape = Shape4(in_grad[softmax_activation::kData].shape_[0], + in_grad[softmax_activation::kData].shape_[1], 1, 1); + grad = out_grad[softmax_activation::kOut].get_with_shape(dshape, s); + output_data = out_data[softmax_activation::kOut].get_with_shape(dshape, s); + input_grad = in_grad[softmax_activation::kData].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_INSTANCE; + } else { + CHECK_GE(in_grad[softmax_activation::kData].ndim(), 3) + << "Input need to have a least 3 dimensions when type=channel"; + Shape<4> dshape; + index_t size_left = in_grad[softmax_activation::kData].Size(); + for (int i = 0; i < 3; ++i) { + if (i < in_grad[softmax_activation::kData].ndim()) { + dshape[i] = in_grad[softmax_activation::kData].shape_[i]; + } else { + dshape[i] = 1; + } + size_left /= dshape[i]; + } + dshape[3] = size_left; + output_data = out_data[softmax_activation::kOut].get_with_shape(dshape, s); + grad = out_grad[softmax_activation::kOut].get_with_shape(dshape, s); + input_grad = in_grad[softmax_activation::kData].get_with_shape(dshape, s); + softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; + } + CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); + CHECK_EQ(cudnnSoftmaxBackward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + &beta, + shape_desc_, + input_grad.dptr_), CUDNN_STATUS_SUCCESS); + } + + private: + bool init_cudnn_; + cudnnDataType_t dtype_; + cudnnTensorDescriptor_t shape_desc_; + SoftmaxActivationParam param_; +}; // class CuDNNSoftmaxActivationOp +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CUDNN_SOFTMAX_ACTIVATION_INL_H_ diff --git a/src/operator/prod_sum-inl.h b/src/operator/prod_sum-inl.h deleted file mode 100644 index 79b8c68d7475..000000000000 --- a/src/operator/prod_sum-inl.h +++ /dev/null @@ -1,179 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - * \file leaky_relu-inl.h - * \brief leaky relu family operator - * \author Bing Xu -*/ -#ifndef MXNET_OPERATOR_PROD_SUM_INL_H_ -#define MXNET_OPERATOR_PROD_SUM_INL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "./operator_common.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { - -namespace prodsum { -enum ProdSumOpInputs {kLhs, kRhs}; -enum ProdSumOpOutputs {kOut}; -} // namespace prodsum - -struct ProdSumParam : public dmlc::Parameter { - index_t dot_dim; - DMLC_DECLARE_PARAMETER(ProdSumParam) { - DMLC_DECLARE_FIELD(dot_dim) - .describe("The dimension along with to do dot product."); - } -}; - -template -class ProdSumOp : public Operator { - public: - explicit ProdSumOp(ProdSumParam param) { - param_ = param; - } - - virtual void Forward(const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - TShape lshape = in_data[prodsum::kLhs].shape_; - Shape<3> ishape = ShapeCheck(in_data); - Stream *s = ctx.get_stream(); - Tensor lhs = in_data[prodsum::kLhs] - .get_with_shape(ishape, s); - Tensor rhs = in_data[prodsum::kRhs] - .get_with_shape(ishape, s); - Tensor out = out_data[prodsum::kOut] - .get_with_shape(Shape2(ishape[0], ishape[2]), s); - Assign(out, req[prodsum::kOut], (reduce_with_axis(lhs*rhs))); - } - - virtual void Backward(const OpContext & ctx, - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &req, - const std::vector &in_grad, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - TShape lshape = in_data[prodsum::kLhs].shape_; - Shape<3> ishape = ShapeCheck(in_data); - Stream *s = ctx.get_stream(); - Tensor lhs = in_data[prodsum::kLhs] - .get_with_shape(ishape, s); - Tensor lhs_grad = in_grad[prodsum::kLhs] - .get_with_shape(ishape, s); - Tensor rhs = in_data[prodsum::kRhs] - .get_with_shape(ishape, s); - Tensor rhs_grad = in_grad[prodsum::kRhs] - .get_with_shape(ishape, s); - Tensor top = out_grad[prodsum::kOut] - .get_with_shape(Shape2(ishape[0], ishape[2]), s); - Assign(lhs_grad, req[prodsum::kLhs], (broadcast_with_axis<0>(top, ishape[1])*rhs)); - Assign(rhs_grad, req[prodsum::kRhs], (broadcast_with_axis<0>(top, ishape[1])*lhs)); - } - - private: - ProdSumParam param_; - - mshadow::Shape<3> ShapeCheck(const std::vector &in_data) { - index_t leading = 1, trailing = 1; - TShape lshape = in_data[prodsum::kLhs].shape_; - TShape rshape = in_data[prodsum::kRhs].shape_; - CHECK_EQ(lshape, rshape) << "Shape of two inputs must match"; - CHECK(lshape.ndim() > param_.dot_dim) - << "Inputs must have more dimensions than dot_dim"; - for (index_t i = 0; i < param_.dot_dim; ++i) { - leading *= lshape[i]; - } - for (index_t i = param_.dot_dim+1; i < lshape.ndim(); ++i) { - trailing *= lshape[i]; - } - return mshadow::Shape3(leading, lshape[param_.dot_dim], trailing); - } -}; // class ProdSumOp - -template -Operator* CreateOp(ProdSumParam type); - -#if DMLC_USE_CXX11 -class ProdSumProp : public OperatorProperty { - public: - void Init(const std::vector >& kwargs) override { - param_.Init(kwargs); - } - - std::map GetParams() const override { - return param_.__DICT__(); - } - - bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { - TShape lshape = in_shape->at(0); - TShape rshape = 
in_shape->at(1); - CHECK_EQ(lshape, rshape) << "Shape of two inputs must match"; - CHECK(lshape.ndim() > param_.dot_dim) - << "Inputs must have more dimensions than dot_dim"; - std::vector s; - for (index_t i = 0; i < lshape.ndim(); ++i) { - if (i != param_.dot_dim) { - s.push_back(lshape[i]); - } - } - TShape oshape(s.begin(), s.end()); - out_shape->clear(); - out_shape->push_back(oshape); - return true; - } - - OperatorProperty* Copy() const override { - auto ptr = new ProdSumProp(); - ptr->param_ = param_; - return ptr; - } - - std::string TypeString() const override { - return "ProdSum"; - } - - // decalre dependency and inplace optimization options - std::vector DeclareBackwardDependency( - const std::vector &out_grad, - const std::vector &in_data, - const std::vector &out_data) const override { - return {in_data[prodsum::kLhs], in_data[prodsum::kRhs], out_grad[prodsum::kOut]}; - } - - std::vector ListArguments() const override { - return {"lhs", "rhs"}; - } - - std::vector ListOutputs() const override { - return {"output"}; - } - - Operator* CreateOperator(Context ctx) const override; - - private: - ProdSumParam param_; -}; -#endif // DMLC_USE_CXX11 -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_PROD_SUM_INL_H_ - diff --git a/src/operator/prod_sum.cc b/src/operator/prod_sum.cc deleted file mode 100644 index 33e50eafe7de..000000000000 --- a/src/operator/prod_sum.cc +++ /dev/null @@ -1,30 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file prod_sum.cc - * \brief product sum op - * \author Junyuan Xie -*/ -#include "./prod_sum-inl.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(ProdSumParam param) { - return new ProdSumOp(param); -} - -// DO_BIND_DISPATCH comes from operator_common.h -Operator *ProdSumProp::CreateOperator(Context ctx) const { - DO_BIND_DISPATCH(CreateOp, param_); -} - -DMLC_REGISTER_PARAMETER(ProdSumParam); - -MXNET_REGISTER_OP_PROPERTY(ProdSum, ProdSumProp) -.describe("Compute dot product along one dim of 2 tensors.") -.add_arguments(ProdSumParam::__FIELDS__()); - -} // namespace op -} // namespace mxnet - diff --git a/src/operator/prod_sum.cu b/src/operator/prod_sum.cu deleted file mode 100644 index c345be586a31..000000000000 --- a/src/operator/prod_sum.cu +++ /dev/null @@ -1,17 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file prod_sum.cu - * \brief - * \author Junyuan Xie -*/ -#include "./prod_sum-inl.h" -#include "./mshadow_op.h" - -namespace mxnet { -namespace op { -template<> -Operator *CreateOp(ProdSumParam param) { - return new ProdSumOp(param); -} -} // op -} // namespace mxnet diff --git a/src/operator/softmax_activation-inl.h b/src/operator/softmax_activation-inl.h new file mode 100644 index 000000000000..ccf50d323c6c --- /dev/null +++ b/src/operator/softmax_activation-inl.h @@ -0,0 +1,162 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file softmax_activation-inl.h + * \brief SoftmaxActivation operator + * \author Junyuan Xie +*/ +#ifndef MXNET_OPERATOR_SOFTMAX_ACTIVATION_INL_H_ +#define MXNET_OPERATOR_SOFTMAX_ACTIVATION_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { +// Declare enumeration of input order to make code more intuitive. 
+// // These enums are only visible within this header +namespace softmax_activation { +enum SoftmaxActivationOpInputs {kData}; +enum SoftmaxActivationOpOutputs {kOut}; +enum SoftmaxActivationOpType {kInstance, kChannel}; +} // softmax_activation + +struct SoftmaxActivationParam : public dmlc::Parameter { + // use int for enumeration + int type; + DMLC_DECLARE_PARAMETER(SoftmaxActivationParam) { + DMLC_DECLARE_FIELD(type) + .add_enum("instance", softmax_activation::kInstance) + .add_enum("channel", softmax_activation::kChannel) + .set_default(softmax_activation::kInstance) + .describe("Softmax Mode. If set to instance, this operator will compute a " + "softmax for each instance in the batch; this is the default mode. " + "If set to channel, this operator will compute a num_channel-class softmax at " + "each position of each instance; this can be used for fully convolutional network, " + "image segmentation, etc."); + } +}; + +/** + * \brief This is the implementation of softmax_activation operator. + * \tparam xpu The device that the op will be executed on. + */ +template +class SoftmaxActivationOp : public Operator { + public: + explicit SoftmaxActivationOp(SoftmaxActivationParam p) { + this->param_ = p; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + // Stream *s = ctx.get_stream(); + // Tensor data = in_data[softmax_activation::kData].FlatTo2D(s); + // Tensor out = out_data[softmax_activation::kOut].FlatTo2D(s); + LOG(FATAL) << "non-cuDNN version not implemented yet."; + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK(in_data.size() == 1 && in_grad.size() == 1); + CHECK_EQ(req.size(), 1); + // Stream *s = ctx.get_stream(); + // Tensor m_out_grad = out_grad[softmax_activation::kOut].FlatTo2D(s); + // Tensor m_out_data = out_data[softmax_activation::kOut].FlatTo2D(s); + // Tensor m_in_grad = in_grad[softmax_activation::kData].FlatTo2D(s); + LOG(FATAL) << "non-cuDNN version not implemented yet."; + } + + private: + SoftmaxActivationParam param_; +}; // class SoftmaxActivationOp + +// Decalre Factory function, used for dispatch specialization +template +Operator* CreateOp(SoftmaxActivationParam type); + +#if DMLC_USE_CXX11 +class SoftmaxActivationProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; + const TShape &dshape = in_shape->at(softmax_activation::kData); + if (dshape.ndim() == 0) return false; + out_shape->clear(); + out_shape->push_back(dshape); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new SoftmaxActivationProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "SoftmaxActivation"; + } + + // decalre dependency and inplace optimization options + 
std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[softmax_activation::kOut], out_data[softmax_activation::kOut]}; + } + + std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const override { + return {{out_grad[softmax_activation::kOut], in_grad[softmax_activation::kData]}}; + } + + std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const override { + return {{in_data[softmax_activation::kData], out_data[softmax_activation::kOut]}}; + } + + Operator* CreateOperator(Context ctx) const override; + + private: + SoftmaxActivationParam param_; +}; +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_SOFTMAX_ACTIVATION_INL_H_ diff --git a/src/operator/softmax_activation.cc b/src/operator/softmax_activation.cc new file mode 100644 index 000000000000..89d795051eb3 --- /dev/null +++ b/src/operator/softmax_activation.cc @@ -0,0 +1,38 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file activation.cc + * \brief softmax_activation op + * \author Junyuan Xie +*/ +#include "./softmax_activation-inl.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(SoftmaxActivationParam param) { + LOG(FATAL) << "Softmax Activation for internal layers is only supported " + "on GPU with cuDNN. Use SoftmaxOutput for loss layer."; + return new SoftmaxActivationOp(param); +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator *SoftmaxActivationProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(SoftmaxActivationParam); + +MXNET_REGISTER_OP_PROPERTY(SoftmaxActivation, SoftmaxActivationProp) +.describe("Apply softmax activation to input. This is intended for internal layers. " + "For output (loss layer) please use SoftmaxOutput. If type=instance, " + "this operator will compute a softmax for each instance in the batch; " + "this is the default mode. If type=channel, this operator will compute " + "a num_channel-class softmax at each position of each instance; this can " + "be used for fully convolutional network, image segmentation, etc.") +.add_argument("data", "Symbol", "Input data to activation function.") +.add_arguments(SoftmaxActivationParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/softmax_activation.cu b/src/operator/softmax_activation.cu new file mode 100644 index 000000000000..4720e3c69fb7 --- /dev/null +++ b/src/operator/softmax_activation.cu @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file softmax_activation.cu + * \brief + * \author Junyuan Xie +*/ +#include "./softmax_activation-inl.h" +#include "./mshadow_op.h" +#if MXNET_USE_CUDNN == 1 +#include "./cudnn_softmax_activation-inl.h" +#endif + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(SoftmaxActivationParam param) { +#if MXNET_USE_CUDNN == 1 + return new CuDNNSoftmaxActivationOp(param); +#else + LOG(FATAL) << "Softmax Activation for internal layers is only supported " + "on GPU with cuDNN. 
Use SoftmaxOutput for loss layer."; + return new SoftmaxActivationOp(param); +#endif // MXNET_USE_CUDNN +} +} // op +} // namespace mxnet + diff --git a/src/operator/softmax_output-inl.h b/src/operator/softmax_output-inl.h index 4bbdf474ceb0..60877a6b0c3c 100644 --- a/src/operator/softmax_output-inl.h +++ b/src/operator/softmax_output-inl.h @@ -2,7 +2,7 @@ * Copyright (c) 2015 by Contributors * \file softmax_output-inl.h * \brief - * \author Bing Xu + * \author Junyuan Xie */ #ifndef MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ #define MXNET_OPERATOR_SOFTMAX_OUTPUT_INL_H_ @@ -89,7 +89,7 @@ class SoftmaxOutputOp : public Operator { Tensor out = out_data[softmaxout_enum::kOut].get_with_shape(s3, s); Tensor grad = in_grad[softmaxout_enum::kData].get_with_shape(s3, s); SoftmaxGrad(grad, out, label); - grad *= param_.grad_scale; + grad *= param_.grad_scale/s3[2]; } else { Tensor label = in_data[softmaxout_enum::kLabel].get(s); Tensor out = out_data[softmaxout_enum::kOut].FlatTo2D(s); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index bf9d1dde8fa4..4579bdbdf4cb 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -621,31 +621,7 @@ def test_nearest_upsampling(): shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] check_nearest_upsampling_with_shape(shapes, scale, root_scale) -def check_prod_sum_with_shape(shape, dot_dim): - x = mx.sym.Variable('x') - X = mx.random.uniform(-1, 1, shape=shape, ctx=mx.cpu()) - dX = mx.nd.zeros(shape, ctx=mx.cpu()) - y = mx.sym.Variable('y') - Y = mx.random.uniform(-1, 1, shape=shape, ctx=mx.cpu()) - dY = mx.nd.zeros(shape, ctx=mx.cpu()) - z = mx.sym.ProdSum(lhs=x, rhs=y, dot_dim=dot_dim) - exe = z.bind(mx.cpu(), args={'x':X, 'y': Y}, args_grad={'x': dX, 'y': dY}) - exe.forward(is_train=True) - assert_allclose(exe.outputs[0].asnumpy(), np.sum(X.asnumpy()*Y.asnumpy(), axis=dot_dim), rtol=1e-4) - dZ = mx.nd.ones(exe.outputs[0].shape, ctx=mx.cpu()) - exe.backward(dZ) - assert_allclose(dX.asnumpy(), Y.asnumpy(), rtol=1e-4) - assert_allclose(dY.asnumpy(), X.asnumpy(), rtol=1e-4) - - -def test_prod_sum(): - check_prod_sum_with_shape((3,5,3), 0) - check_prod_sum_with_shape((3,5,3), 1) - check_prod_sum_with_shape((3,5,3), 2) - - if __name__ == '__main__': - test_prod_sum(); test_nearest_upsampling() test_binary_op_duplicate_input() test_elementwise_sum() diff --git a/tools/caffe_converter/run.sh b/tools/caffe_converter/run.sh index 83d74fe47774..65876cc42934 100755 --- a/tools/caffe_converter/run.sh +++ b/tools/caffe_converter/run.sh @@ -1,7 +1,7 @@ #!/bin/bash if [[ $# -ne 1 ]]; then echo "usage: $0 model_name" - echo " model_name: vgg19, ..." + echo " model_name: [vgg16|vgg19], ..." exit -1 fi @@ -16,6 +16,17 @@ if [[ $1 == "vgg19" ]]; then echo "converting" python `dirname $0`/convert_model.py VGG_ILSVRC_19_layers_deploy.prototxt VGG_ILSVRC_19_layers.caffemodel vgg19 +elif [[ $1 == "vgg16" ]]; then + if [[ ! -f VGG_ILSVRC_16_layers_deploy.prototxt ]]; then + wget -c https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/c3ba00e272d9f48594acef1f67e5fd12aff7a806/VGG_ILSVRC_16_layers_deploy.prototxt + fi + + if [[ ! 
-f VGG_ILSVRC_16_layers.caffemodel ]]; then + wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel + fi + + echo "converting" + python `dirname $0`/convert_model.py VGG_ILSVRC_16_layers_deploy.prototxt VGG_ILSVRC_16_layers.caffemodel vgg16 else echo "unsupported model: $1" fi
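
With the softmax branch split out of Activation, an internal per-channel softmax now goes through the new SoftmaxActivation operator (GPU with cuDNN only in this series; the CPU path aborts with a fatal error), and FeedForward gains the ccsgd default optimizer plus a num_batch limit on predict(). A rough sketch of the new symbol, assuming a hypothetical toy network (the layer names and 3x3 convolution are illustrative only):

    import mxnet as mx

    # internal per-channel softmax; the final loss layer should still be SoftmaxOutput
    data = mx.sym.Variable('data')
    conv = mx.sym.Convolution(data=data, num_filter=8, kernel=(3, 3), name='conv')
    attn = mx.sym.SoftmaxActivation(data=conv, type='channel', name='attn')
    net = mx.sym.SoftmaxOutput(data=mx.sym.Flatten(data=attn), name='softmax')

    # prediction can now stop early, e.g. model.predict(val_iter, num_batch=5)
    # runs only the first 5 batches of a (hypothetical) validation iterator.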