diff --git a/Makefile b/Makefile index d13688e6f0dc..50e9a21c50e8 100644 --- a/Makefile +++ b/Makefile @@ -58,14 +58,14 @@ endif BIN = test/api_registry_test test/test_storage OBJ = narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o graph_executor.o pooling_cpu.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o + CUOBJ += narray_op_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o endif .PHONY: clean all test lint doc @@ -81,6 +81,7 @@ narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h symbol.o: src/symbol/symbol.cc +graph_executor.o: src/symbol/graph_executor.cc static_graph.o : src/symbol/static_graph.cc registry.o: src/registry.cc c_api.o: src/c_api.cc @@ -91,6 +92,9 @@ activation_cpu.o: src/operator/activation.cc activation_gpu.o: src/operator/activation.cu elementwise_sum_cpu.o: src/operator/elementwise_sum.cc elementwise_sum_gpu.o: src/operator/elementwise_sum.cu +pooling_cpu.o: src/operator/pooling.cc +pooling_gpu.o: src/operator/pooling.cu + lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index fe035b21bc7f..38132cb169a5 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -34,12 +34,11 @@ typedef void *AtomicSymbolCreator; typedef void *SymbolHandle; /*! \brief handle to a AtomicSymbol */ typedef void *AtomicSymbolHandle; -/*! \brief handle to a NArrayOperator */ -typedef void *OperatorHandle; +/*! \brief handle to an Executor */ +typedef void *ExecutorHandle; /*! \brief handle to a DataIterator */ typedef void *DataIterHandle; - -/*! +/* * \brief return str message of the last error * all function in this file will return 0 when success * and -1 when an error occured, @@ -325,6 +324,7 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym, * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. * + * \param sym symbol handle * \param num_args numbe of input arguments. * \param keys the key of keyword args (optional) * \param arg_ind_ptr the head pointer of the rows in CSR @@ -351,63 +351,59 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, const mx_uint ***out_shape_data, int *complete); //-------------------------------------------- -// Part 4: operator interface on NArray +// Part 4: Executor interface //-------------------------------------------- /*! - * \brief create operator from symbol - * \param sym the symbol to create operator from - * \param dev_mask device mask to indicate the device type - * \param dev_id the device id we want to bind the symbol to - * \param out the corresponding function handle - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXOpCreate(SymbolHandle sym, - int dev_mask, - int dev_id, - OperatorHandle *out); -/*! - * \brief free the operator handle - * \param op the handle to be freed + * \brief Executor forward method + * + * \param handle executor handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXOpFree(OperatorHandle op); +MXNET_DLL int MXExecutorForward(ExecutorHandle handle); /*! - * \brief return an array to describe the arguments - * of this operator - * \param out_size the size of output array - * \param out_array the array of parameter requirments + * \brief Excecutor run backward + * + * \param handle execute handle + * \param len lenth + * \param head_grads NArray handle for heads' gradient + * * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXOpDescribeArgs(mx_uint *out_size, - int **out_array); +MXNET_DLL int MXExecutorBackward(ExecutorHandle handle, + mx_uint len, + NArrayHandle *head_grads); + /*! - * \brief call forward on the operator - * \param op the operator handle - * \param in_data array of input narray to the operator - * \param out_data array of output NArray to hold the result + * \brief Get executor's head NArray + * + * \param handle executor handle + * \param out_size output narray vector size + * \param out out put narray handles * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXOpForward(OperatorHandle op, - NArrayHandle *in_data, - NArrayHandle *out_data); +MXNET_DLL int MXExecutorHeads(ExecutorHandle handle, + mx_uint *out_size, + NArrayHandle **out); + /*! - * \brief call backward on the operator - * \param op the operator handle - * \param grad_next array of output gradients - * \param in_data array of input narray to the operator - * \param out_data array of output narray to the operator - * \param out_grad array to holds the gradient on these input - * can be NULL if that position request is kNullOp - * \param reqs gradient request type + * \brief Generate Executor from symbol + * + * \param symbol_handle symbol handle + * \param len length + * \param in_args in args array + * \param arg_grad_store arg grads handle array + * \param grad_req_type grad req array + * \param out output executor handle * \return 0 when success, -1 when failure happens - * \sa mxnet::Operator::GradReqType */ -MXNET_DLL int MXOpBackward(OperatorHandle op, - NArrayHandle *grad_next, - NArrayHandle *in_data, - NArrayHandle *out_data, - NArrayHandle *out_grad, - mx_uint *reqs); +MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle, + int dev_mask, + int dev_id, + mx_uint len, + NArrayHandle *in_args, + NArrayHandle *arg_grad_store, + mx_uint *grad_req_type, + ExecutorHandle *out); //-------------------------------------------- // Part 5: IO Interface diff --git a/include/mxnet/context.h b/include/mxnet/context.h index 8dfa618ca180..700bb36f0abb 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -33,6 +33,14 @@ struct Context { inline bool operator==(const Context &b) const { return dev_mask == b.dev_mask && dev_id == b.dev_id; } + /*! + * \brief check if current context not equals another one + * \param b another context to compare + * \return whether they are not the same + */ + inline bool operator!=(const Context &b) const { + return !(*this == b); + } }; /*! diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 92257b3f0269..ed2b72bc4cc5 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -35,7 +35,7 @@ class NArray { */ NArray(const TShape &shape, Context ctx, bool delay_alloc = false) - : ptr_(new Chunk(shape, ctx, delay_alloc)) { + : ptr_(new Chunk(shape.Size(), ctx, delay_alloc)), shape_(shape), offset_(0) { } /*! * \brief constructing a static NArray that shares data with TBlob @@ -45,19 +45,20 @@ class NArray { * \param dev_id the device id this tensor sits at */ NArray(const TBlob &data, int dev_id) - : ptr_(new Chunk(data, dev_id)) { + : ptr_(new Chunk(data, dev_id)), shape_(data.shape_), offset_(0) { } /*! * \return the shape of current NArray */ inline const TShape &shape() const { - return ptr_->data.shape_; + return shape_; } /*! * \return the data TBlob */ - inline const TBlob &data() const { - return ptr_->data; + inline TBlob data() const { + return TBlob(static_cast(ptr_->shandle.dptr) + offset_, \ + shape_, ptr_->shandle.ctx.dev_mask); } /*! * \return the context of NArray, this function is only valid when the NArray is not empty @@ -123,6 +124,43 @@ class NArray { * \return the new copy */ NArray Copy(Context ctx) const; + /*! + * \brief Slice a NArray + * + * \param begin begin index in first dim + * \param end end index in first dim + * + * \return sliced NArray + */ + inline NArray Slice(index_t begin, index_t end) const { + NArray ret = *this; + CHECK_GE(shape_.ndim(), 0) << "NArray not initialized"; + CHECK_GE(shape_[0], end) << "Chunk is smaller than required"; + size_t length = 1; + if (shape_.ndim() == 1) { + ret.offset_= begin; + } else { + for (index_t i = 1; i < shape_.ndim(); ++i) { + length *= shape_[i]; + } + ret.offset_ = begin * length; + } + ret.shape_[0] = end - begin; + return ret; + } + /*! + * \brief Reshape current NArray + * + * \param shape new shape + * \return NArray in new shape + */ + inline NArray Reshape(const TShape &shape) const { + CHECK_GE(shape_.Size(), shape.Size()) \ + << "required shape is larger than chunk"; + NArray ret = *this; + ret.shape_ = shape; + return ret; + } private: /*! \brief the real data chunk that backs NArray */ @@ -131,8 +169,6 @@ class NArray { Storage::Handle shandle; /*! \brief variable from DAG engine */ DAGEngine::Variable var; - /*! \brief holds the data content */ - TBlob data; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed @@ -146,25 +182,25 @@ class NArray { } /*! \brief construct from static data */ Chunk(const TBlob &data, int dev_id) - : data(data), - static_data(true), + : static_data(true), delay_alloc(false) { var = DAGEngine::Get()->NewVar(); shandle.ctx = Context(data.dev_mask_, dev_id); + shandle.dptr = data.dptr_; + shandle.size = data.shape_.Size() * sizeof(real_t); } /*! \brief construct a new chunk */ - Chunk(const TShape &shape, Context ctx, bool delay_alloc_) + Chunk(uint64_t size, Context ctx, bool delay_alloc_) : static_data(false), delay_alloc(true) { var = DAGEngine::Get()->NewVar(); - data.shape_ = shape; + shandle.size = size * sizeof(real_t); shandle.ctx = ctx; if (!delay_alloc_) this->CheckAndAlloc(); } /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { - shandle = Storage::Get()->Alloc(data.shape_.Size() * sizeof(real_t), shandle.ctx); - data = TBlob(static_cast(shandle.dptr), data.shape_, shandle.ctx.dev_mask); + shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx); delay_alloc = false; } } @@ -183,6 +219,11 @@ class NArray { }; /*! \brief internal data of NArray */ std::shared_ptr ptr_; + /*! \brief shape of current NArray */ + TShape shape_; + /*! \brief offset in chunk */ + size_t offset_; + // add friend to helper functions friend void CopyFromTo(const NArray &from, NArray *to); template diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 0fa1fb6a0571..e60afe6948a7 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -8,6 +8,7 @@ #define MXNET_OPERATOR_H_ #include +#include #include #include #include @@ -108,7 +109,9 @@ class Operator { const std::vector &in_data, const std::vector &out_data, const std::vector &req, - const std::vector &in_grad) = 0; + const std::vector &in_grad) { + LOG(FATAL) << "Backward is not implemented"; + } }; #if DMLC_USE_CXX11 @@ -255,28 +258,36 @@ class OperatorProperty { * This function enables optimization to reuse memory of inputs in output. * Only override when necessary, by default in-place is disabled. * + * The reason for void* type in the out_data is to distinguish the order + * of mappings between the two, compiler will report error when + * in_data and out_data's order in the pair get reversed. + * * \code * // The following code says out_data[0] can share data with in_data[0] - * vector > ForwardInplaceOption(const vector &in_data, - * const vector &out_data) const { - * return {{out_data[0], in_data[0]}}; + * vector > ForwardInplaceOption(const vector &in_data, + * const vector &out_data) const { + * return {{in_data[0], out_data[0]}}; * } * \endcode * \param in_data The input data in forward pass. * \param out_data The output data in forward pass. - * \return list of pair of integers taken from the inputs vector, + * \return list of pair of that maps input->output, * indicating possible in place operations. */ - virtual std::vector > ForwardInplaceOption( + virtual std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { - return std::vector >(); + const std::vector &out_data) const { + return std::vector >(); } /*! * \brief Get possible backward inplace options. * This function enables optimization to reuse memory of inputs in output. * Only override when necessary, by default in-place is disabled. * + * The reason for void* type in the in_grad is to distinguish the order + * of mappings between the two, compiler will report error when + * in_data and out_data's order in the pair get reversed. + * * \code * // The following code says in_grad[0] can share data with in_data[0] * vector > BackwardInplaceOption( @@ -284,22 +295,22 @@ class OperatorProperty { * const std::vector &in_data, * const std::vector &out_data, * const std::vector &in_grad) const { - * return {in_grad[0], in_data[0]}}; + * return {in_data[0], in_grad[0]}}; * } * \endcode * \param in_data The input data in forward pass. * \param out_data The output data in forward pass. * \param in_grad Gradient of inputs in backward pass. * \param out_grad Gradient of outputs in backward pass. - * \return list of pair of integers taken from the inputs vector, + * \return list of pair of that maps input->output, * indicating possible in place operations. */ - virtual std::vector > BackwardInplaceOption( + virtual std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { - return std::vector >(); + const std::vector &in_grad) const { + return std::vector >(); } /*! * \brief Get Backward Input Dependency for generic types of data. @@ -314,31 +325,35 @@ class OperatorProperty { * \sa DeclareBackwardDependency */ template - inline std::vector BackwardInputs(const std::vector &in_data, - const std::vector &out_data, - const std::vector &out_grad) const { - int cnt = 0; - std::vector all_vec; - std::vector in_data_idx, out_data_idx, out_grad_idx; - for (size_t i = 0; i < in_data.size(); ++i) { - in_data_idx.push_back(cnt++); - all_vec.push_back(in_data[i]); + inline std::vector BackwardInputs(const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + int counter = 0; + std::vector out_grad_index(out_grad.size()); + std::vector in_data_index(in_data.size()); + std::vector out_data_index(out_data.size()); + for (size_t i = 0; i < out_grad_index.size(); ++i) { + out_grad_index[i] = counter++; } - for (size_t i = 0; i < out_data.size(); ++i) { - out_data_idx.push_back(cnt++); - all_vec.push_back(out_data[i]); + for (size_t i = 0; i < in_data_index.size(); ++i) { + in_data_index[i] = counter++; } - for (size_t i = 0; i < out_grad.size(); ++i) { - out_grad_idx.push_back(cnt++); - all_vec.push_back(out_data[i]); + for (size_t i = 0; i < out_data_index.size(); ++i) { + out_data_index[i] = counter++; } - std::vector ret_idx = this->DeclareBackwardDependency( - in_data_idx, out_data_idx, out_grad_idx); - std::vector ret; - for (size_t i = 0; i < ret_idx.size(); ++i) { - ret.push_back(all_vec[ret_idx[i]]); + std::vector all_data; + all_data.insert(all_data.end(), out_grad.begin(), out_grad.end()); + all_data.insert(all_data.end(), in_data.begin(), in_data.end()); + all_data.insert(all_data.end(), out_data.begin(), out_data.end()); + + std::vector ret_index = this->DeclareBackwardDependency( + out_grad_index, in_data_index, out_data_index); + + std::vector ret(ret_index.size()); + for (size_t i = 0; i < ret_index.size(); ++i) { + ret[i] = all_data[ret_index[i]]; } - return ret; + return std::move(ret); } /*! * \brief create OperatorProperty diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index e24c03a0cd0b..df06c4913de8 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -158,11 +158,17 @@ class StaticGraph { * The head and input of Backward pass will be returned by head_grad_nodes and arg_grads. * * \param head_grad_nodes used to store the created head gradient inputs for backward pass. -<<<<<<< HEAD * \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator */ void MakeBackwardPass(std::vector *head_grad_nodes, std::vector > *arg_grads); + + /*! + * \brief create a sum node that aggregates gradient together + * \param grad_source the source of the inputs. + * \return a created ElementWiseSum node + */ + static Node CreateSumNode(const std::vector &grad_source); }; /*! diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 8cb698aa8219..e30c77d382a3 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -69,7 +69,7 @@ def _load_lib(): FunctionHandle = ctypes.c_void_p SymbolCreatorHandle = ctypes.c_void_p SymbolHandle = ctypes.c_void_p - +ExecutorHandle = ctypes.c_void_p #---------------------------- # helper function definition #---------------------------- diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py new file mode 100644 index 000000000000..7352bfe2f289 --- /dev/null +++ b/python/mxnet/executor.py @@ -0,0 +1,57 @@ +# coding: utf-8 +""" code for executor. """ +from __future__ import absolute_import + +import ctypes +from .base import _LIB +from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle +from .base import check_call +from .narray import NArray + +class Executor(object): + """ Executor is the actual executing object of MXNet.""" + def __init__(self, handle): + """Init an executor from handle + + Parameters + ---------- + handle: ExecutorHandle + ExecutorHandle generated by calling Bind + """ + if not isinstance(handle, ExecutorHandle): + raise TypeError("Handle type error") + self.handle = handle + + def forward(self): + """Do forward.""" + check_call(_LIB.MXExecutorForward(self.handle)) + + def backward(self, grads): + """Do backward on heads' gradient. + + Parameters + ---------- + grads: Array of NArray + heads' gradient + """ + for obj in grads: + if not isinstance(obj, NArray): + raise TypeError("inputs must be NArray") + narray = c_array(NArrayHandle, [item.handle for item in grads]) + check_call(_LIB.MXExecutorBackward(self.handle, len(grads), narray)) + + def heads(self): + """list all heads' output narray + + Returns + ------- + A list of narray binded to the heads of executor. + """ + # TODO: think of access, make heads read only. + # (consider support read only NArray(NArrayView)) + # Otherwise some of the internal might depends on out_data + # if user set the content of the head, the backward behavior can be incorrect. + out_size = mx_uint() + handles = ctypes.POINTER(NArrayHandle)() + check_call(_LIB.MXExecutorHeads(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) + return [NArray(NArrayHandle(handles[i])) for i in range(out_size.value)] diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 0caa4b6a0a90..6c72442cb3f9 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1,16 +1,18 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access +# pylint: disable=invalid-name, protected-access, too-many-locals """Symbol support of mxnet""" from __future__ import absolute_import import ctypes from .base import _LIB -from .base import c_array, c_str, mx_uint -from .base import SymbolHandle +from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call +from .narray import NArray +from .context import Context +from .executor import Executor class Symbol(object): - """SymbolCreator is a function that takes Param and return symbol""" + """Symbol is symbolic graph of the mxnet.""" _registry = None @staticmethod @@ -162,7 +164,8 @@ def infer_shape(self, *args, **kwargs): The order is in the same order as list_returns() """ if len(args) != 0 and len(kwargs) != 0: - raise ValueError('Can only specify known argument shapes either by positional or kwargs way.') + raise ValueError('Can only specify known argument \ + shapes either by positional or kwargs way.') sdata = [] indptr = [0] if len(args) != 0: @@ -188,21 +191,23 @@ def infer_shape(self, *args, **kwargs): out_shape_ndim = ctypes.POINTER(mx_uint)() out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() complete = ctypes.c_int() - check_call(_LIB.MXSymbolInferShape( - self.handle, len(indptr) - 1, - c_array(ctypes.c_char_p, keys), - c_array(mx_uint, indptr), - c_array(mx_uint, sdata), - ctypes.byref(arg_shape_size), - ctypes.byref(arg_shape_ndim), - ctypes.byref(arg_shape_data), - ctypes.byref(out_shape_size), - ctypes.byref(out_shape_ndim), - ctypes.byref(out_shape_data), + check_call(_LIB.MXSymbolInferShape( \ + self.handle, len(indptr) - 1, \ + c_array(ctypes.c_char_p, keys), \ + c_array(mx_uint, indptr), \ + c_array(mx_uint, sdata), \ + ctypes.byref(arg_shape_size), \ + ctypes.byref(arg_shape_ndim), \ + ctypes.byref(arg_shape_data), \ + ctypes.byref(out_shape_size), \ + ctypes.byref(out_shape_ndim), \ + ctypes.byref(out_shape_data), \ ctypes.byref(complete))) if complete.value != 0: - arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) for i in range(arg_shape_size.value)] - out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) for i in range(out_shape_size.value)] + arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) \ + for i in range(arg_shape_size.value)] + out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) \ + for i in range(out_shape_size.value)] return (arg_shapes, out_shapes) else: return (None, None) @@ -216,6 +221,40 @@ def debug_str(self): Debug string of the symbol. """ debug_str = ctypes.c_char_p() - check_call(_LIB.MXSymbolPrint( + check_call(_LIB.MXSymbolPrint( \ self.handle, ctypes.byref(debug_str))) return debug_str.value + + def bind(self, ctx, args, args_grad, reqs): + """bind current symbol to get an executor. + + Parameters + ---------- + ctx: Context + context executor to run on + args: Array of NArray + input args to the symbol + args_grad: Array of NArray + input args' gradient + reqs: Array of enum + graident requirements + """ + # TODO(bing): consider a more friendly interface + # For example, pass in args_grad by dict + + enum = {"null" : 0, "write_to" : 1, "in_place":2, "add_to" : 3} + if not isinstance(ctx, Context): + raise TypeError("Context type error") + args_handle = c_array(NArrayHandle, [item.handle for item in args]) + args_grad_handle = c_array(NArrayHandle, [item.handle for item in args_grad]) + reqs_array = c_array(mx_uint, [mx_uint(enum[item]) for item in reqs]) + handle = ExecutorHandle() + check_call(_LIB.MXExecutorBind(self.handle, + mx_uint(ctx.device_mask), + mx_uint(ctx.device_id), + len(args), + args_handle, + args_grad_handle, + reqs_array, + ctypes.byref(handle))) + return Executor(handle) diff --git a/python/test_mnist.py b/python/test_mnist.py new file mode 100644 index 000000000000..1075ceaf6206 --- /dev/null +++ b/python/test_mnist.py @@ -0,0 +1,137 @@ +# pylint: skip-file +import mxnet as mx +import numpy as np +import os, cPickle, gzip + +def Softmax(x): + batch, nidden = x.shape + maxes = np.max(x, axis=1) + x -= maxes.reshape(batch, 1) + x = np.exp(x) + norm = np.sum(x, axis=1) + prob = x / norm.reshape((batch, 1)) + return prob + +def CalAcc(out, label): + pred = np.argmax(out, axis=1) + return np.sum(pred == label) * 1.0 / out.shape[0] + +def SetGradient(out_grad, label): + assert(out_grad.shape[0] == label.shape[0]) + for i in xrange(label.shape[0]): + k = label[i] + out_grad[i][k] -= 1.0 + +# load data +class MNISTIter(object): + def __init__(self, which_set, batch_size=100): + if not os.path.exists('mnist.pkl.gz'): + os.system("wget http://deeplearning.net/data/mnist/mnist.pkl.gz") + f = gzip.open('mnist.pkl.gz', 'rb') + train_set, valid_set, test_set = cPickle.load(f) + f.close() + if which_set == 'train': + self.data = train_set[0] + self.label = np.asarray(train_set[1]) + elif which_set == 'valid': + self.data = valid_set[0] + self.label = np.asarray(valid_set[1]) + else: + self.data = test_set[0] + self.data = np.asarray(test_set[1]) + self.batch_size = batch_size + self.nbatch = self.data.shape[0] / batch_size + assert(self.data.shape[0] % batch_size == 0) # I am lazy + self.now_idx = -1 + def BeforeFirst(self): + self.now_idx = -1 + def Next(self): + self.now_idx += 1 + if self.now_idx == self.nbatch: + return False + return True + def Get(self): + if self.now_idx < 0: + raise Exception("Iterator is at head") + elif self.now_idx >= self.nbatch: + raise Exception("Iterator is at end") + start = self.now_idx * self.batch_size + end = (self.now_idx + 1) * self.batch_size + return (self.data[start:end, :], self.label[start:end]) + + + +# symbol net +batch_size = 100 +data = mx.sym.Variable('data') +fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=160) +act1 = mx.sym.Activation(data = fc1, name='relu1', type="relu") +fc2 = mx.sym.FullyConnected(data = act1, name='fc2', num_hidden=10) +args_list = fc2.list_arguments() +# infer shape +data_shape = (batch_size, 784) +arg_shapes, out_shapes = fc2.infer_shape(data=data_shape) +arg_narrays = [mx.narray.create(shape) for shape in arg_shapes] +grad_narrays = [mx.narray.create(shape) for shape in arg_shapes] +mom_narrays = [mx.narray.create(shape) for shape in arg_shapes] +inputs = dict(zip(args_list, arg_narrays)) + +np.random.seed(0) +# set random weight +for name, narray in inputs.items(): + if "weight" in name: + narray.numpy[:, :] = np.random.uniform(-0.001, 0.001, narray.numpy.shape) +req = ['write_to' for i in range(len(arg_narrays))] +# bind executer +# TODO(bing): think of a better bind interface +executor = fc2.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req) +# update + +out_narray = executor.heads()[0] +grad_narray = mx.narray.create(out_narray.shape) + +epoch = 10 +momentum = 0.9 +lr = 0.001 +wd = 0.0004 + +def Update(mom, grad, weight): + weight.numpy[:] -= lr * grad.numpy[:] + +block = zip(mom_narrays, grad_narrays, arg_narrays) + + +train = MNISTIter("train", batch_size) +valid = MNISTIter("valid", batch_size) + +for i in xrange(epoch): + # train + print "Epoch %d" % i + train_acc = 0.0 + val_acc = 0.0 + while train.Next(): + data, label = train.Get() + inputs["data"].numpy[:] = data + executor.forward() + out_narray.numpy[:] = Softmax(out_narray.numpy) + train_acc += CalAcc(out_narray.numpy, label) + grad_narray.numpy[:] = out_narray.numpy + SetGradient(grad_narray.numpy, label) + executor.backward([grad_narray]) + + for mom, grad, weight in block: + Update(mom, grad, weight) + + # evaluate + while valid.Next(): + data, label = valid.Get() + inputs["data"].numpy[:] = data + executor.forward() + val_acc += CalAcc(out_narray.numpy, label) + print "Train Acc: ", train_acc / train.nbatch + print "Valid Acc: ", val_acc / valid.nbatch + train.BeforeFirst() + valid.BeforeFirst() + + + diff --git a/src/c_api.cc b/src/c_api.cc index ed5446fc816a..3d5e03cc0748 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -40,6 +40,8 @@ struct MXAPIThreadLocalEntry { std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ std::vector ret_vec_charp; + /*! \brief result holder for returning handles */ + std::vector ret_handles; /*! \brief result holder for returning shapes */ std::vector arg_shapes, out_shapes; /*! \brief result holder for returning shape dimensions */ @@ -480,3 +482,68 @@ int MXSymbolInferShape(SymbolHandle sym, } API_END(); } + +int MXExecutorForward(ExecutorHandle handle) { + API_BEGIN(); + Executor *exec = static_cast(handle); + exec->Forward(); + API_END(); +} + +int MXExecutorBackward(ExecutorHandle handle, + mx_uint len, + NArrayHandle *head_grads) { + API_BEGIN(); + Executor *exec = static_cast(handle); + std::vector narrays; + NArray **args_ptr = reinterpret_cast(head_grads); + for (mx_uint i = 0; i < len; ++i) { + narrays.push_back(*args_ptr[i]); + } + exec->Backward(narrays); + API_END(); +} + +int MXExecutorHeads(ExecutorHandle handle, + mx_uint *out_size, + NArrayHandle **out) { + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + API_BEGIN(); + Executor *exec = static_cast(handle); + std::vector heads = exec->heads(); + ret->ret_handles.resize(heads.size()); + for (size_t i = 0; i < heads.size(); ++i) { + NArray *ptr = new NArray(); + *ptr = heads[i]; + ret->ret_handles[i] = ptr; + } + *out_size = heads.size(); + *out = dmlc::BeginPtr(ret->ret_handles); + API_END(); +} + +int MXExecutorBind(SymbolHandle symbol_handle, + int dev_mask, + int dev_id, + mx_uint len, + NArrayHandle *in_args, + NArrayHandle *arg_grad_store, + mx_uint *grad_req_type, + ExecutorHandle *out) { + API_BEGIN(); + Symbol *symb = static_cast(symbol_handle); + Context ctx = Context(dev_mask, dev_id); + NArray **in_args_ptr = reinterpret_cast(in_args); + NArray **arg_grad_ptr = reinterpret_cast(arg_grad_store); + std::vector in_args_vec; + std::vector arg_grad_vec; + std::vector grad_req_vec; + for (mx_uint i = 0; i < len; ++i) { + in_args_vec.push_back(*(in_args_ptr[i])); + arg_grad_vec.push_back(*(arg_grad_ptr[i])); + grad_req_vec.push_back(static_cast(grad_req_type[i])); + } + *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec); + API_END(); +} + diff --git a/src/narray/narray.cc b/src/narray/narray.cc index 831041bd1496..3618a38c9d59 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -37,14 +37,16 @@ inline void BinaryOp(const NArray &lhs, case cpu::kDevMask: DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); break; #if MXNET_USE_CUDA case gpu::kDevMask: DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + TBlob tmp = ret.data(); + narray::Eval(lhs.data(), rhs.data(), &tmp, ctx); }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); break; #endif @@ -64,14 +66,16 @@ void CopyFromTo(const NArray &from, NArray *to) { if (a == cpu::kDevMask && b == cpu::kDevMask) { DAGEngine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Copy(from.ptr_->data, &ret.ptr_->data, + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); } else if (a == cpu::kDevMask && b == gpu::kDevMask) { #if MXNET_USE_CUDA DAGEngine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Copy(from.ptr_->data, &ret.ptr_->data, + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, ret.ctx(), {from.ptr_->var}, {ret.ptr_->var}); #else @@ -81,7 +85,8 @@ void CopyFromTo(const NArray &from, NArray *to) { #if MXNET_USE_CUDA DAGEngine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Copy(from.ptr_->data, &ret.ptr_->data, + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); #else @@ -91,7 +96,8 @@ void CopyFromTo(const NArray &from, NArray *to) { #if MXNET_USE_CUDA DAGEngine::Get()->Push([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); - narray::Copy(from.ptr_->data, &ret.ptr_->data, + TBlob tmp = ret.data(); + narray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx); }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); #else diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 6374d02cc53b..3d57d6a88102 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -28,8 +29,8 @@ struct ActivationParam : public dmlc::Parameter { // use int for enumeration int type; DMLC_DECLARE_PARAMETER(ActivationParam) { - // TODO(bing) support enum, str->int mapping - DMLC_DECLARE_FIELD(type).set_default(kReLU); + DMLC_DECLARE_FIELD(type).set_default(kReLU).add_enum("relu", kReLU).\ + add_enum("sigmoid", kSigmoid).add_enum("tanh", kTanh); } }; @@ -115,17 +116,17 @@ class ActivationProp : public OperatorProperty { return {out_grad[kOut], out_data[kOut]}; } - virtual std::vector > BackwardInplaceOption( + virtual std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const { return {{out_grad[kOut], in_grad[kData]}}; } - virtual std::vector > ForwardInplaceOption( + virtual std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const { return {{in_data[kData], out_data[kOut]}}; } diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h index f0a558b3b0cc..4a0d6e3fdd57 100644 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -145,17 +146,17 @@ class ElementWiseSumProp : public OperatorProperty { return out_grad; } - virtual std::vector > BackwardInplaceOption( + virtual std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { + const std::vector &in_grad) const { return {{out_grad[0], in_grad[0]}}; } - virtual std::vector > ForwardInplaceOption( + virtual std::vector > ForwardInplaceOption( const std::vector &in_data, - const std::vector &out_data) const { + const std::vector &out_data) const { return {{in_data[0], out_data[0]}}; } diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index 9dbb9bda8649..b49e5c422739 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -9,11 +9,12 @@ #include #include #include +#include #include #include #include #include "./operator_common.h" -#include "./param.h" + namespace mxnet { namespace op { @@ -121,9 +122,7 @@ class FullyConnectedProp : public OperatorProperty { } virtual void Init(const std::vector >& kwargs) { - // TODO(bing) change directly to vector of pairs begin end - std::map kmap(kwargs.begin(), kwargs.end()); - param_.Init(kmap); + param_.Init(kwargs); } virtual bool InferShape(std::vector *in_shape, @@ -175,12 +174,12 @@ class FullyConnectedProp : public OperatorProperty { return {out_grad[kOut], in_data[kData], in_data[kWeight]}; } - virtual std::vector > BackwardInplaceOption( + virtual std::vector > BackwardInplaceOption( const std::vector &out_grad, const std::vector &in_data, const std::vector &out_data, - const std::vector &in_grad) const { - return {{in_grad[kData], in_data[kData]}}; + const std::vector &in_grad) const { + return {{in_data[kData], in_grad[kData]}}; } Operator* CreateOperator(Context ctx) const; diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h new file mode 100644 index 000000000000..8b223e2476a2 --- /dev/null +++ b/src/operator/pooling-inl.h @@ -0,0 +1,201 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file pooling-inl.h + * \brief + * \author Bing Xu +*/ + +#ifndef MXNET_OPERATOR_POOLING_INL_H_ +#define MXNET_OPERATOR_POOLING_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { +enum PoolingOpInputs {kData}; +enum PoolingOpOutputs {kOut}; +enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling}; + +struct PoolingParam : public dmlc::Parameter { + int kernel_x; + int kernel_y; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int type; + DMLC_DECLARE_PARAMETER(PoolingParam) { + // TODO(bing) change to only set lower bound + DMLC_DECLARE_FIELD(kernel_x).set_range(1, 10000); + DMLC_DECLARE_FIELD(kernel_y).set_range(1, 10000); + DMLC_DECLARE_FIELD(stride_x).set_range(1, 10000); + DMLC_DECLARE_FIELD(stride_y).set_range(1, 10000); + DMLC_DECLARE_FIELD(pad_x).set_default(0).set_range(0, 10000); + DMLC_DECLARE_FIELD(pad_y).set_default(0).set_range(0, 10000); + DMLC_DECLARE_FIELD(type).set_default(kMaxPooling)\ + .add_enum("max", kMaxPooling).add_enum("avg", kAvgPooling)\ + .add_enum("sum", kSumPooling); + } +}; + +template +class PoolingOp : public Operator { + public: + explicit PoolingOp(PoolingParam p) { + this->param_ = p; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(req[kOut], kWriteTo); + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[kData].get(s); + Tensor out = out_data[kOut].get(s); + mshadow::Shape<2> out_shape = Shape2(out.shape_[2], out.shape_[3]); + // TODO(bing): dual stride in mshadow + if (param_.type == kMaxPooling || param_.type == kSumPooling) { + out = pool(pad(data, param_.pad_y, param_.pad_x), + out_shape, + param_.kernel_y, + param_.kernel_x, + param_.kernel_y); + } else if (param_.type == kAvgPooling) { + out = (1.0f / (param_.kernel_y * param_.kernel_x)) * \ + pool(pad(data, param_.pad_y, param_.pad_x), + out_shape, + param_.kernel_y, + param_.kernel_x, + param_.kernel_y); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req.size(), 1); + CHECK_EQ(in_grad.size(), 1); + // TODO(bing): remove pad (0,0) + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[kOut].get(s); + Tensor data = in_data[kData].get(s); + Tensor output_data = out_data[kOut].get(s); + Tensor input_grad = in_grad[kData].get(s); + + mshadow::Shape<2> in_shape = Shape2(data.shape_[2], data.shape_[3]); + + if (param_.type == kMaxPooling || param_.type == kSumPooling) { + Assign(input_grad, req[kData], + crop(unpool(pad(data, param_.pad_y, param_.pad_x), + pad(output_data, 0, 0), + pad(grad, 0, 0), + param_.kernel_y, + param_.kernel_x, + param_.stride_y), + in_shape, + param_.pad_y, + param_.pad_x)); + } else if (param_.type == kAvgPooling) { + Assign(input_grad, req[kData], + (1.0f / param_.kernel_y / param_.kernel_x) *\ + crop(unpool(pad(data, param_.pad_y, param_.pad_x), + pad(output_data, 0, 0), + pad(grad, 0, 0), + param_.kernel_y, + param_.kernel_x, + param_.stride_y), + in_shape, + param_.pad_y, + param_.pad_x)); + } + } + + private: + PoolingParam param_; +}; // class PoolingOp + +template +Operator* CreateOp(PoolingParam param); + + +#if DMLC_USE_CXX11 +class PoolingProp : public OperatorProperty { + public: + virtual void Init(const std::vector >& kwargs) { + param_.Init(kwargs); + } + + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const { + CHECK_EQ(in_shape->size(), 1); + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4) << \ + "Pooling: Input data should be 4D in (batch, channel, y, x)"; + TShape oshape = dshape; + if (dshape.ndim() == 0) return false; + oshape[2] = std::min(dshape[2] + 2 * param_.pad_y - param_.kernel_y + param_.stride_y - 1, + dshape[2] + 2 * param_.pad_y - 1) / param_.stride_y + 1; + oshape[3] = std::min(dshape[3] + 2 * param_.pad_x - param_.kernel_x + param_.stride_x - 1, + dshape[3] + 2 * param_.pad_x - 1) / param_.stride_x + 1; + CHECK(oshape[2] > 0 && oshape[3] > 0) << "Pooling: kernel size exceed input"; + out_shape->clear(); + out_shape->push_back(oshape); + return true; + } + + virtual OperatorProperty* Copy() const { + PoolingProp *prop_sym = new PoolingProp(); + prop_sym->param_ = this->param_; + return prop_sym; + } + + virtual std::string TypeString() const { + return "Pooling"; + } + + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + return {out_grad[kOut], in_data[kData], out_data[kOut]}; + } + + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return {{in_data[kData], in_grad[kData]}}; + } + + Operator* CreateOperator(Context ctx) const; + + private: + PoolingParam param_; +}; // class PoolingProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_POOLING_INL_H_ diff --git a/src/operator/pooling.cc b/src/operator/pooling.cc new file mode 100644 index 000000000000..a6ebc91e0873 --- /dev/null +++ b/src/operator/pooling.cc @@ -0,0 +1,34 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file pooling.cc + * \brief + * \author Bing Xu +*/ + +#include +#include "./pooling-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(PoolingParam param) { + switch (param.type) { + case kMaxPooling: return new PoolingOp(param); + case kAvgPooling: return new PoolingOp(param); + case kSumPooling: return new PoolingOp(param); + default: + LOG(FATAL) << "unknown activation type"; + return NULL; + } +} + +Operator* PoolingProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(PoolingParam); + +REGISTER_OP_PROPERTY(Pooling, PoolingProp); +} // namespace op +} // namespace mxnet + diff --git a/src/operator/pooling.cu b/src/operator/pooling.cu new file mode 100644 index 000000000000..2db6d9ea549a --- /dev/null +++ b/src/operator/pooling.cu @@ -0,0 +1,26 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file pooling.cu + * \brief + * \author Bing Xu +*/ + +#include "./pooling-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(PoolingParam param) { + switch (param.type) { + case kMaxPooling: return new PoolingOp(param); + case kAvgPooling: return new PoolingOp(param); + case kSumPooling: return new PoolingOp(param); + default: + LOG(FATAL) << "unknown activation type"; + return NULL; + } +} + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/static_operator/pooling_op-inl.h b/src/operator/static_operator/pooling_op-inl.h deleted file mode 100644 index 8c6014a8c2cf..000000000000 --- a/src/operator/static_operator/pooling_op-inl.h +++ /dev/null @@ -1,153 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file pooling_op-inl.h - * \brief pooling operator - * \author Bing Xu -*/ -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ - -#include -#include -#include -#include "./param.h" -#include "./static_operator_common.h" - - -namespace mxnet { -namespace op { -template -class PoolingOp : public StaticOperator { - public: - virtual void SetParam(const char *name, const char *val) { - param_.SetParam(name, val); - } - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) { - CHECK_EQ(in_shape->size(), 1) << "Input: [data]"; - CHECK_GT(param_.kernel_y, 0); - CHECK_GT(param_.kernel_x, 0); - const int ksize_y = static_cast(param_.kernel_y); - const int ksize_x = static_cast(param_.kernel_x); - const int pad_y = static_cast(param_.pad_y); - const int pad_x = static_cast(param_.pad_x); - // TODO(bing): dual stride - const int kstride = static_cast(param_.stride_y); - mshadow::Shape<4> ishape = (*in_shape)[0].get<4>(); - oshape_ = ishape; - fea_shape_ = mshadow::Shape2(ishape[2], ishape[3]); - oshape_[2] = std::min(ishape[2] + 2 * pad_y - ksize_y + kstride - 1, - ishape[2] + 2 * pad_y - 1) / kstride + 1; - oshape_[3] = std::min(ishape[3] + 2 * pad_x - ksize_x + kstride - 1, - ishape[3] + 2 * pad_x - 1) / kstride + 1; - CHECK(oshape_[2] > 0 && oshape_[3] > 0) << "kernel size exceed input"; - out_shape->clear(); - out_shape->push_back((*in_shape)[0]); - (*out_shape)[0][2] = oshape_[2]; - (*out_shape)[0][3] = oshape_[3]; - } - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - CHECK_EQ(in_data.size(), 1); - CHECK_EQ(out_data.size(), 0); - if (!(temp_.shape_ == oshape_)) { - temp_.Resize(oshape_); - } - const int ksize_y = param_.kernel_y; - const int ksize_x = param_.kernel_x; - const int pad_y = param_.pad_y; - const int pad_x = param_.pad_x; - // TODO(bing): dual stride - const int kstride = param_.stride_y; - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = static_cast *>(ctx.stream); - Tensor data = in_data[0].get(s); - Tensor out = out_data[0].get(s); - mshadow::Shape<2> pshape = Shape2(out.shape_[2], out.shape_[3]); - if (mode == kMaxPooling || mode == kSumPooling) { - temp_ = pool(pad(data, pad_y, pad_x), - pshape, - ksize_y, - ksize_x, - kstride); - } else if (mode == kAvgPooling) { - temp_ = (1.0f / (ksize_y * ksize_x)) * \ - pool(pad(data, pad_y, pad_x), - pshape, - ksize_y, - ksize_x, - kstride); - } else { - LOG(FATAL) << "Unknown pooling mode"; - } - Copy(out, temp_, s); - } - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &out_grad, - const std::vector &req) { - CHECK_EQ(grad_next.size(), 1); - CHECK_EQ(in_data.size(), 1); - CHECK_EQ(out_grad.size(), 1); - CHECK_EQ(req.size(), 1); - const int ksize_y = param_.kernel_y; - const int ksize_x = param_.kernel_x; - const int pad_y = param_.pad_y; - const int pad_x = param_.pad_x; - // TODO(bing): dual stride - const int kstride = param_.stride_y; - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = static_cast *>(ctx.stream); - Tensor grad = grad_next[0].get(s); - Tensor data = in_data[0].get(s); - Tensor out = out_grad[0].get(s); - if (mode == kMaxPooling || mode == kSumPooling) { - Assign(out, - req[0], - crop(unpool(pad(data, pad_y, pad_x), - pad(temp_, 0, 0), - pad(grad, 0, 0), - ksize_y, - ksize_x, - kstride), - fea_shape_, - pad_y, - pad_x)); - } else if (mode == kAvgPooling) { - Assign(out, - req[0], - (1.0f / (ksize_y * ksize_x)) * \ - crop(unpool(pad(data, pad_y, pad_x), - pad(temp_, 0, 0), - pad(grad, 0, 0), - ksize_y, - ksize_x, - kstride), - fea_shape_, - pad_y, - pad_x)); - } else { - LOG(FATAL) << "Unknown pooling mode"; - } - } - - private: - /*! \brief parameters that potentially be useful */ - Param param_; - /*! \brief temp space to save pooled result */ - mshadow::TensorContainer temp_; - /*! \brief pooled output shape */ - mshadow::Shape<4> oshape_; - /*! \brief input feature map shape */ - mshadow::Shape<2> fea_shape_; -}; // class PoolingOp - -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_STATIC_OPERATOR_POOLING_OP_INL_H_ diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc new file mode 100644 index 000000000000..a434f22a2fc6 --- /dev/null +++ b/src/symbol/graph_executor.cc @@ -0,0 +1,496 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file graph_executor.cc + * \brief Executor to execute the Graph. +*/ +#include +#include +#include +#include "./graph_executor.h" + +namespace mxnet { +/*! + * \brief wrapper class that wraps Backward operation as Forward. + */ +class GraphExecutor::BackwardOpWrapper : public Operator { + public: + /*! + * \brief create a backward Operator wrapper given forward op. + * \param prop pointer to the property of forward wrapper + * \param forward_op the shared ptr to Forward operator + * \return the created wrapper. + */ + explicit BackwardOpWrapper(const OperatorProperty *prop, + std::shared_ptr forward_op) + : op_(forward_op) { + out_grad_.resize(prop->NumVisibleReturns()); + in_data_.resize(prop->ListArguments().size()); + out_data_.resize(prop->NumReturns()); + + std::vector out_grad_ptr(out_grad_.size()); + for (size_t i = 0; i < out_grad_.size(); ++i) { + out_grad_ptr[i] = &out_grad_[i]; + } + std::vector in_data_ptr(in_data_.size()); + for (size_t i = 0; i < in_data_.size(); ++i) { + in_data_ptr[i] = &in_data_[i]; + } + std::vector out_data_ptr(out_data_.size()); + for (size_t i = 0; i < out_data_.size(); ++i) { + out_data_ptr[i] = &out_data_[i]; + } + arg_data_ptr_ = prop->BackwardInputs( + out_grad_ptr, in_data_ptr, out_data_ptr); + } + // implement forward + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + // set things correctly + CHECK(arg_data_ptr_.size() == in_data.size()); + for (size_t i = 0; i < in_data.size(); ++i) { + *(arg_data_ptr_[i]) = in_data[i]; + } + // redirect internally + op_->Backward(ctx, out_grad_, in_data_, out_data_, req, out_data); + } + + private: + /*! \brief internal forward operator */ + std::shared_ptr op_; + /*! \brief internal space for out_grad */ + std::vector out_grad_; + /*! \brief internal space for in_data */ + std::vector in_data_; + /*! \brief internal space for out_data */ + std::vector out_data_; + /*! + * \brief pointer to places in the internal space. + * arg_data_ptr_ maps in_data in Forward to the internal space. + */ + std::vector arg_data_ptr_; +}; + +// get resource +inline std::vector +GraphExecutor::GetResource(uint32_t node_id) const { + const StaticGraph::Node &node = graph_.nodes[node_id]; + if (node.is_forward()) { + return node.op->ForwardResource(); + } else { + CHECK(node.is_backward()); + return graph_.nodes[node.backward_source_id].op->BackwardResource(); + } +} + +inline int GraphExecutor::GetNumOutputs(uint32_t node_id) const { + const StaticGraph::Node &node = graph_.nodes[node_id]; + if (node.is_forward()) { + return node.op->NumReturns(); + } else if (node.is_backward()) { + return static_cast( + graph_.nodes[node.backward_source_id].op->ListArguments().size()); + } else { + CHECK(node.is_variable()); + return 1; + } +} + +// implement get input option +template +inline std::vector > GraphExecutor::GetInplaceOption( + uint32_t node_id, + const std::vector &in_data, + const std::vector &out_data) const { + // get the node + const StaticGraph::Node &node = graph_.nodes[node_id]; + + if (node.is_forward()) { + std::vector in_data_index(in_data.size()); + for (size_t i = 0; i < in_data.size(); ++i) { + in_data_index[i] = static_cast(i); + } + std::vector out_data_ptr(out_data.size()); + for (size_t i = 0; i < out_data.size(); ++i) { + out_data_ptr[i] = (void*)&out_data[i]; // NOLINT(*) + } + auto rmap_index = node.op->ForwardInplaceOption(in_data_index, out_data_ptr); + std::vector > remap(rmap_index.size()); + for (size_t i = 0; i < remap.size(); ++i) { + remap[i].first = in_data[rmap_index[i].first]; + remap[i].second = *static_cast(rmap_index[i].second); + } + return std::move(remap); + } else { + CHECK(node.is_backward()); + // forward property + const OperatorProperty *fwd = graph_.nodes[node.backward_source_id].op.get(); + + std::vector out_grad_index(fwd->NumVisibleReturns()); + std::vector in_data_index(fwd->ListArguments().size()); + std::vector out_data_index(fwd->NumReturns()); + CHECK_EQ(in_data_index.size(), out_data.size()); + int counter = 0; + for (size_t i = 0; i < out_grad_index.size(); ++i) { + out_grad_index[i] = counter++; + } + for (size_t i = 0; i < in_data_index.size(); ++i) { + in_data_index[i] = counter++; + } + for (size_t i = 0; i < out_data_index.size(); ++i) { + out_data_index[i] = counter++; + } + auto args_index = fwd->DeclareBackwardDependency( + out_grad_index, in_data_index, out_data_index); + std::vector args_array(counter, nullptr); + CHECK_EQ(args_index.size(), in_data.size()); + for (size_t i = 0; i < in_data.size(); ++i) { + args_array[args_index[i]] = &in_data[i]; + } + std::vector in_grad_ptr(out_data.size()); + for (size_t i = 0; i < in_grad_ptr.size(); ++i) { + in_grad_ptr[i] = (void*)&out_data[i]; // NOLINT(*) + } + auto remap_index = fwd->BackwardInplaceOption( + out_grad_index, in_data_index, out_data_index, in_grad_ptr); + std::vector > remap(remap_index.size()); + for (size_t i = 0; i < remap_index.size(); ++i) { + CHECK_NE(args_array[remap_index[i].first], nullptr) + << "BackwardInplaceOption uses input that is returned by DeclareBackwardDependency"; + remap[i].first = *args_array[remap_index[i].first]; + remap[i].second = *static_cast(remap_index[i].second); + } + return std::move(remap); + } +} + +inline GraphExecutor::OpExecEntry +GraphExecutor::GetOpExecEntry(uint32_t nid) { + OpNode& op_node = op_nodes_[nid]; + Operator *op = op_node.op.get(); + std::vector req; + std::vector in_data, out_data; + in_data.reserve(graph_.nodes[nid].inputs.size()); + out_data.reserve(op_node.outputs.size()); + req.reserve(op_node.outputs.size()); + + OpExecEntry exec; + for (const DataEntryInfo& out : op_node.outputs) { + out_data.push_back(out.data.data()); + exec.mutate_vars.push_back(out.data.var()); + req.push_back(out.op_req); + } + + for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { + const DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + in_data.push_back(info.data.data()); + // skip inplace since they already appear in mutate vars + if (info.inplace_op_id != static_cast(nid)) { + exec.use_vars.push_back(info.data.var()); + } + } + + OpContext* op_ctx_ptr = &op_node.op_ctx; + exec.exec_fun = [op, op_ctx_ptr, in_data, req, out_data] (RunContext ctx) { + op_ctx_ptr->run_ctx = ctx; + op->Forward(*op_ctx_ptr, in_data, req, out_data); + }; + return std::move(exec); +} + +void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) { + // initialize all internal daa structures + symbol.ToStaticGraph(&graph_); + num_forward_nodes_ = graph_.nodes.size(); + if (need_backward) { + graph_.MakeBackwardPass(&head_grad_nodes_, &arg_grads_); + } + // reorganize so backward node always follow forward + // note that this may not be the case, because existence of head_grad_nodes + std::vector topo = graph_.TopoSort(); + std::vector backward; + for (uint32_t nid : topo) { + if (nid < num_forward_nodes_) { + topo_order_.push_back(nid); + } else { + backward.push_back(nid); + } + } + topo_order_.insert(topo_order_.end(), backward.begin(), backward.end()); + // setup all the operator nodes data structure + op_nodes_.resize(graph_.nodes.size()); + for (size_t i = 0; i < graph_.nodes.size(); ++i) { + op_nodes_[i].ctx = ctx; + op_nodes_[i].outputs.resize(GetNumOutputs(i)); + } +} + +void GraphExecutor::InitDataEntryInfo(const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type) { + CHECK_EQ(arg_grad_store.size(), grad_req_type.size()); + CHECK_EQ(in_args.size(), graph_.arg_nodes.size()); + // bind inputs + for (size_t i = 0; i < graph_.arg_nodes.size(); ++i) { + DataEntryInfo &info = op_nodes_[graph_.arg_nodes[i]].outputs[0]; + info.type = kBindByExternal; + info.data = in_args[i]; + } + // setup ref for head nodes + for (StaticGraph::DataEntry e : graph_.heads) { + DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + ++info.ref_count; + op_nodes_[e.source_id].activated = true; + } + // need Backward pass + if (arg_grads_.size() != 0) { + CHECK_EQ(arg_grads_.size(), arg_grad_store.size()); + CHECK_EQ(arg_grads_.size(), grad_req_type.size()); + // setup gradient placeholders + for (size_t i = 0; i < arg_grads_.size(); ++i) { + if (grad_req_type[i] == kNullOp) continue; + CHECK_NE(grad_req_type[i], kWriteInplace) + << "Gradient request can only be nullop, add, write"; + std::vector &grad_source = arg_grads_[i]; + CHECK_GE(grad_source.size(), 1); + // TODO(bing) add a aggregation node here + if (grad_source.size() > 1) { + CHECK_EQ(grad_req_type[i], kAddTo) + << "The gradient contains multiple variables,"; + } + for (StaticGraph::DataEntry e : grad_source) { + DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + info.type = kBindByExternal; + info.op_req = grad_req_type[i]; + info.data = arg_grad_store[i]; + ++info.ref_count; + op_nodes_[e.source_id].activated = true; + } + } + // setup head gradient + for (uint32_t nid : head_grad_nodes_) { + DataEntryInfo &info = op_nodes_[nid].outputs[0]; + info.type = kTobeBindByExternal; + } + } + // update ref counters for all other nodes, in reverse topo order + for (auto it = topo_order_.rbegin(); it != topo_order_.rend(); ++it) { + uint32_t nid = *it; + if (op_nodes_[nid].activated) { + for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { + DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + ++info.ref_count; + op_nodes_[e.source_id].activated = true; + } + } + } + + // shape inference + std::vector > out_shapes(op_nodes_.size()); + for (size_t i = 0; i < out_shapes.size(); ++i) { + out_shapes[i].resize(op_nodes_[i].outputs.size()); + } + for (size_t i = 0; i < graph_.arg_nodes.size(); ++i) { + out_shapes[graph_.arg_nodes[i]][0] = in_args[i].shape(); + } + CHECK(graph_.InferNodeShapes(topo_order_, &out_shapes)) + << "Shape inference cannot be complete in bind"; + for (size_t i = 0; i < out_shapes.size(); ++i) { + for (size_t j = 0; j < out_shapes[i].size(); ++j) { + op_nodes_[i].outputs[j].shape = out_shapes[i][j]; + } + } +} + +void GraphExecutor::InitDataEntryMemory() { + // use allocator to allocate memory. + GraphStorageAllocator allocator(&graph_); + for (size_t i = 0; i < topo_order_.size(); ++i) { + uint32_t nid = topo_order_[i]; + if (!op_nodes_[nid].activated) continue; + if (graph_.nodes[nid].is_variable()) continue; + + // check inplace option + std::vector in_data; + in_data.reserve(graph_.nodes[nid].inputs.size()); + // check inputs are ready. + for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { + DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + CHECK_NE(info.type, kNotInitialized); + CHECK_NE(info.ref_count, 0); + in_data.push_back(&info); + } + std::vector out_data(op_nodes_[nid].outputs.size()); + for (size_t i = 0; i < op_nodes_[nid].outputs.size(); ++i) { + out_data[i] = &op_nodes_[nid].outputs[i]; + CHECK_NE(out_data[i]->type, kInternalAllocated); + } + auto inplace = GetInplaceOption(nid, in_data, out_data); + + for (std::pair kv : inplace) { + DataEntryInfo* in = kv.first; + DataEntryInfo* out = kv.second; + if (in->ref_count == 1 && + in->type == kInternalAllocated && + out->type == kNotInitialized) { + // we can only do inplace if we are last user of in + // and out is not initialized. + out->type = kInternalAllocated; + out->op_req = kWriteInplace; + out->storage_id = in->storage_id; + // set inplace op id + in->ref_count = 0; + in->inplace_op_id = static_cast(nid); + } + } + // allocate output, + for (DataEntryInfo *out : out_data) { + if (out->op_req == kNullOp && out->ref_count != 0) { + out->op_req = kWriteTo; + } + if (out->type == kNotInitialized) { + out->storage_id = allocator.Request( + op_nodes_[nid].ctx, out->shape, nid); + out->type = kInternalAllocated; + } + } + // then free inputs + for (DataEntryInfo *in : in_data) { + // ref_count == 0 means it is taken by inplace op + if (in->ref_count == 0) { + CHECK_EQ(in->inplace_op_id, static_cast(nid)); + continue; + } + // if we decrease it to zero, means we are ready to relase + --in->ref_count; + if (in->ref_count == 0 && in->type == kInternalAllocated) { + allocator.Release(in->storage_id, nid); + } + } + // check out again, if there is ref_count == 0, release it + for (DataEntryInfo *out : out_data) { + if (out->ref_count == 0 && out->type == kInternalAllocated) { + allocator.Release(out->storage_id, nid); + } + } + } + // one pass complete, allocate real memory + allocator.InitStorages(); + // get the real data NArray into the DataEntryInfo + for (size_t i = 0; i < topo_order_.size(); ++i) { + uint32_t nid = topo_order_[i]; + if (!op_nodes_[nid].activated) continue; + for (DataEntryInfo &out : op_nodes_[nid].outputs) { + CHECK_NE(out.type, kNotInitialized); + if (out.type == kInternalAllocated) { + out.data = allocator.Get(out.storage_id, out.shape); + } + } + } + for (StaticGraph::DataEntry e : graph_.heads) { + DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; + CHECK_EQ(info.type, kInternalAllocated); + heads_narray_.push_back(info.data); + } +} + +void GraphExecutor::InitOpNodes() { + for (size_t i = 0; i < topo_order_.size(); ++i) { + uint32_t nid = topo_order_[i]; + if (!op_nodes_[nid].activated) continue; + if (graph_.nodes[nid].is_variable()) continue; + OpNode& op_node = op_nodes_[nid]; + if (graph_.nodes[nid].is_forward()) { + op_node.op.reset(graph_.nodes[nid].op->CreateOperator(op_node.ctx)); + } else { + CHECK(graph_.nodes[nid].is_backward()); + op_node.op.reset(new BackwardOpWrapper( + graph_.nodes[graph_.nodes[nid].backward_source_id].op.get(), + op_nodes_[graph_.nodes[nid].backward_source_id].op)); + } + bool allow_cache = true; + for (StaticGraph::DataEntry e : graph_.nodes[nid].inputs) { + DataEntryInfo& info = op_nodes_[e.source_id].outputs[e.index]; + if (info.type == kTobeBindByExternal) allow_cache = false; + } + for (DataEntryInfo& info : op_node.outputs) { + if (info.type == kTobeBindByExternal) allow_cache = false; + } + if (allow_cache) { + op_node.cached_exec = GetOpExecEntry(nid); + } + } +} + +void GraphExecutor::RunOps(size_t topo_start, size_t topo_end) { + for (size_t i = topo_start; i < topo_end; ++i) { + uint32_t nid = topo_order_[i]; + if (!op_nodes_[nid].activated) continue; + if (graph_.nodes[nid].is_variable()) continue; + OpNode& opnode = op_nodes_[nid]; + if (opnode.cached_exec.exec_fun != nullptr) { + DAGEngine::Get()->Push( + opnode.cached_exec.exec_fun, + opnode.ctx, + opnode.cached_exec.use_vars, + opnode.cached_exec.mutate_vars); + } else { + auto exec = GetOpExecEntry(nid); + DAGEngine::Get()->Push( + exec.exec_fun, + opnode.ctx, + exec.use_vars, + exec.mutate_vars); + } + } +} + +std::string GraphExecutor::DebugStr() const { + std::ostringstream os; + os << "num_forward_nodes=" << num_forward_nodes_ << '\n'; + for (size_t i = 0; i < topo_order_.size(); ++i) { + uint32_t nid = topo_order_[i]; + if (!op_nodes_[nid].activated) continue; + os << "Op " << i << ":" << graph_.nodes[nid].name << '\n'; + for (size_t j = 0; j < op_nodes_[nid].outputs.size(); ++j) { + const DataEntryInfo &info = op_nodes_[nid].outputs[j]; + os << "\toutput[" << j << "]: shape=" << info.shape; + if (info.storage_id != GraphStorageAllocator::kBadStorageID) { + os << ", storage_id=" << info.storage_id; + } + if (info.inplace_op_id != -1) { + os << ", inplace_consumer=" << graph_.nodes[info.inplace_op_id].name; + } + os << '\n'; + } + } + return os.str(); +} + +void GraphExecutor::Forward() { + RunOps(0, num_forward_nodes_); +} + +void GraphExecutor::Backward(const std::vector &head_grads) { + CHECK_EQ(head_grad_nodes_.size(), head_grads.size()); + for (size_t i = 0; i < head_grad_nodes_.size(); ++i) { + uint32_t nid = head_grad_nodes_[i]; + CHECK(graph_.nodes[nid].is_variable()); + DataEntryInfo &info = op_nodes_[nid].outputs[0]; + CHECK_EQ(info.type, kTobeBindByExternal); + info.data = head_grads[i]; + } + RunOps(num_forward_nodes_, topo_order_.size()); +} + +Executor *Executor::Bind(Symbol symbol, + Context ctx, + const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type) { + GraphExecutor *exec = new GraphExecutor(); + exec->Init(symbol, ctx, in_args, arg_grad_store, grad_req_type); + return exec; +} +} // namespace mxnet diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h new file mode 100644 index 000000000000..a072eee69b68 --- /dev/null +++ b/src/symbol/graph_executor.h @@ -0,0 +1,186 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file graph_executor.h + * \brief Executor to execute the Forward and Backward on Composition Graph. +*/ +#ifndef MXNET_SYMBOL_GRAPH_EXECUTOR_H_ +#define MXNET_SYMBOL_GRAPH_EXECUTOR_H_ + +#include +#include +#include +#include +#include "./graph_memory_allocator.h" + +namespace mxnet { +/*! + * \brief Executor of a computation graph. + */ +class GraphExecutor : public Executor { + public: + virtual ~GraphExecutor() {} + virtual void Forward(); + virtual void Backward(const std::vector &head_grads); + virtual const std::vector &heads() const { + return heads_narray_; + } + // implement Executor::Bind, only call it once. + inline void Init(Symbol symbol, + Context ctx, + const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type) { + CHECK_EQ(grad_req_type.size(), arg_grad_store.size()); + bool need_backward = false; + for (auto req : grad_req_type) { + if (req != kNullOp) need_backward = true; + } + this->InitGraph(symbol, ctx, need_backward); + this->InitDataEntryInfo(in_args, arg_grad_store, grad_req_type); + this->InitDataEntryMemory(); + this->InitOpNodes(); + // TODO(bing): remove me when things are OK + LOG(INFO) << "-----Execution memory plan-----\n" + << DebugStr() << '\n' + << "------------------------------\n"; + } + + protected: + // internal class of wrapping BackwardOp as ForwardOp + class BackwardOpWrapper; + // type of data entry + enum DataEntryType { + // memory is binded by external NArray in Bind + kBindByExternal, + // to be binded by external NArray in Forward and Backward + kTobeBindByExternal, + // internal memory, allocated + kInternalAllocated, + // internal memory, to be allocated + kNotInitialized + }; + // Additional information about each data entry + struct DataEntryInfo { + // the actual data for the entry + NArray data; + // write request to this entry + OpReqType op_req; + // the operatio node that will take + // this DataEntry as inplace input + int inplace_op_id; + // data entry type + DataEntryType type; + // shape of this entry + TShape shape; + // storage id from allocator if it is internal allocation. + GraphStorageAllocator::StorageID storage_id; + // reference count on how many times this entry is being used. + // That is how many operators and heads need this DataEntry + // this is a temporal variable that is used during initialization. + uint32_t ref_count; + // constructor + DataEntryInfo() + : op_req(kNullOp), + inplace_op_id(-1), + type(kNotInitialized), + storage_id(GraphStorageAllocator::kBadStorageID), + ref_count(0) {} + }; + // all the information needed to push the op to engine + struct OpExecEntry { + // execution function for + DAGEngine::Op exec_fun; + // variables to read from + std::vector use_vars; + // variables to mutate + std::vector mutate_vars; + // constructor + OpExecEntry() : exec_fun(nullptr) {} + }; + // Information about operational node + struct OpNode { + // whether this op node is activated + bool activated; + // the context of the node + Context ctx; + // data entry information about outputs of op + std::vector outputs; + // The following parts are constructed in InitOpNodes + // the real operator + std::shared_ptr op; + // op context, that is defined for this op. + OpContext op_ctx; + // executor, this is only allocated for nodes + // whose inputs, outputs are pre-defined. + // otherwise cached_exec.exec_fun == nullptr + OpExecEntry cached_exec; + // constructor + OpNode() : activated(false) {} + }; + /*! + * \brief Get input option of a node. + * This function is overriden for both Forward and Backward node. + * + * \param node_id node index of node in StaticGraph + * \param in_data the input data entry to the node + * \param out_data the output data entry in the graph + * \return the paired inplace option. + */ + template + inline std::vector > GetInplaceOption( + uint32_t node_id, + const std::vector &in_data, + const std::vector &out_data) const; + /*! + * \brief Get resource requirement of a node. + * This function is overriden for both Forward and Backward node. + * \param node_id node index of node in StaticGraph + * \return the desired resource request. + */ + inline std::vector GetResource(uint32_t node_id) const; + /*! + * \brief Get number of outputs of a node. + * This function is overriden for both Forward and Backward node. + * \param node_id node index of node in StaticGraph + * \return the number of outputs of the node. + */ + inline int GetNumOutputs(uint32_t node_id) const; + /*! + * \brief get execution entry for an OpNode. + * This function can only be called after initialization is done. + * \param node_id the id of operational node. + * \return the execution entry. + */ + inline OpExecEntry GetOpExecEntry(uint32_t node_id); + // initialize the internal graph structure + void InitGraph(Symbol symbol, Context ctx, bool need_backward); + // initialize internal DataEntryInfo, reference counting + void InitDataEntryInfo(const std::vector &in_args, + const std::vector &arg_grad_store, + const std::vector &grad_req_type); + // initialize internal data entries NArray + void InitDataEntryMemory(); + // initialize OpNode data structure + void InitOpNodes(); + // run ops from topo order start to end + void RunOps(size_t topo_start, size_t topo_end); + // get debug string + std::string DebugStr() const; + // internal computational graph + StaticGraph graph_; + // topological order of nodes in computation graph + // backward nodes always follow forward nodes + std::vector topo_order_; + // number of forward nodes in the graph + size_t num_forward_nodes_; + // head gradient node in the graph, if there is backward pass + std::vector head_grad_nodes_; + // argument node in the graph, if there is backward pass + std::vector > arg_grads_; + // operational nodes + std::vector op_nodes_; + // head NArrays + std::vector heads_narray_; +}; // class GraphExecutor +} // namespace mxnet +#endif // MXNET_SYMBOL_GRAPH_EXECUTOR_H_ diff --git a/src/symbol/graph_memory_allocator.h b/src/symbol/graph_memory_allocator.h new file mode 100644 index 000000000000..b7bd2db2081e --- /dev/null +++ b/src/symbol/graph_memory_allocator.h @@ -0,0 +1,145 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file graph_memory_allocator.h + * \brief Memory allocator for graph executor. +*/ +#ifndef MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ +#define MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ + +#include +#include +#include +#include + +namespace mxnet { +/*! + * \brief Memory allocators for the GraphExecutor. + * This class is intended to be used by GraphExecutor + * to allocate the memory for each DataEntryInfo. + * + * The class algorithm works in two phase: + * (1) Planning Phase: GraphExecutor call Request and Release + * to request and release resources according to dependency. + * - Each call to Request will get a ResourceID that is used to + * identify the memory block assigned to each DataEntryInfo. + * (2) Allocating phase: GraphExecutor call InitMemory. + * - Then each DataEntry will call Get to get the real NArray. + * (3) All the memory will be freed up when reference to all the related NArray ends. + */ +class GraphStorageAllocator { + public: + /*! \brief resource index */ + typedef int64_t StorageID; + /*! \brief bad storage id */ + static const StorageID kBadStorageID = -1; + /*! \brief constructor to the graph memory allocator */ + explicit GraphStorageAllocator(StaticGraph *graph); + /*! + * \brief Request a memory. + * \param ctx the context of the graph + * \param shape shape of the NArray we want + * \param node_id the node that is requesting the memory, used as hint. + */ + StorageID Request(Context ctx, TShape shape, uint32_t node_id); + /*! + * \brief Release a memory. + * \param id the storage ID of the memory. + * \param node_id the node id in the graph that is releasing the memory. + */ + void Release(StorageID id, uint32_t node_id); + /*! \brief Initialize all the memories requested */ + void InitStorages(); + /*! + * \brief Get the the memory allocated in planning phase. + * \param id the storage id allocated in planning phase. + * \param shape the shape of the NArray requested. + */ + NArray Get(StorageID id, TShape shape); + + private: + /*! \brief internal storage entry */ + struct StorageEntry { + /*! \brief id of the storage */ + StorageID id; + /*! \brief the context of the storage */ + Context ctx; + /*! \brief maximum size of the storage that is requested */ + size_t max_size; + /*! \brief the actual NArray to hold the data */ + NArray data; + /*! \brief constructor */ + StorageEntry() : max_size(0) {} + }; + /*! + * \brief Allocate a StorageID when Request cannot found existing ones. + * \param ctx the context of the graph + * \param shape shape of the NArray we want + */ + StorageID Alloc(Context ctx, size_t size); + + /*! \brief reference to the computation graph */ + StaticGraph *graph_; + /*! \brief all the resources available */ + std::vector > data_; + /*! + * \brief free list of storage entries, maps size to free list + */ + std::multimap free_; +}; + +// put implementation in header files for now +GraphStorageAllocator::GraphStorageAllocator(StaticGraph *graph) + : graph_(graph) {} + +GraphStorageAllocator::StorageID +GraphStorageAllocator::Alloc(Context ctx, size_t size) { + StorageID id = static_cast(data_.size()); + std::unique_ptr ptr(new StorageEntry()); + ptr->id = id; + ptr->ctx = ctx; + ptr->max_size = size; + data_.push_back(std::move(ptr)); + return id; +} + +GraphStorageAllocator::StorageID +GraphStorageAllocator::Request(Context ctx, TShape shape, uint32_t node_id) { + size_t size = shape.Size(); + auto begin = free_.lower_bound(size); + auto end = free_.upper_bound(size); + // vector of possible candidates + for (auto it = begin; it != end; ++it) { + StorageEntry *e = it->second; + if (e->ctx != ctx) continue; + // Use exect matching strategy + // TODO(bing): think of other strategies, for example, rough match. + if (e->max_size != size) continue; + // find a exact match, erase from map and return + free_.erase(it); + return e->id; + } + // cannot find anything return a new one. + return this->Alloc(ctx, size); +} + +void GraphStorageAllocator::Release(StorageID id, uint32_t node_id) { + CHECK_NE(id, kBadStorageID); + StorageEntry *e = data_[id].get(); + free_.insert({e->max_size, e}); +} + +void GraphStorageAllocator::InitStorages() { + for (size_t i = 0; i < data_.size(); ++i) { + StorageEntry *e = data_[i].get(); + TShape shape = mshadow::Shape1(e->max_size); + e->data = NArray(shape, e->ctx); + } +} + +NArray GraphStorageAllocator::Get(StorageID id, TShape shape) { + CHECK_NE(id, kBadStorageID); + StorageEntry *e = data_[id].get(); + return e->data.Slice(0, shape.Size()).Reshape(shape); +} +} // namespace mxnet +#endif // MXNET_SYMBOL_GRAPH_MEMORY_ALLOCATOR_H_ diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 3bec3427fbb3..5eb0ad14a282 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -152,6 +152,18 @@ bool StaticGraph::InferShape(std::vector *in_shape, return true; } +StaticGraph::Node StaticGraph::CreateSumNode( + const std::vector &grad_source) { + // find multiple gradients, need aggregate + std::ostringstream os_size; + Node agg_node; + agg_node.op.reset(OperatorProperty::Create("ElementWiseSum")); + os_size << grad_source.size(); + agg_node.op->Init({{"size", os_size.str()}}); + agg_node.inputs = grad_source; + return std::move(agg_node); +} + void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, std::vector > *arg_grads) { arg_grads->clear(); @@ -162,14 +174,15 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, std::map > grad_map; // allocate head gradient nodes for (DataEntry head : heads) { - uint32_t nid = static_cast(nodes.size()); - // create a variable node for gradient input - nodes.push_back(Node()); - Node &node = nodes[nid]; + Node node; std::ostringstream os; os << nodes[head.source_id].name << '_' << head.index << "_grad"; // TODO(bing): add index to name node.name = os.str(); + // node id + uint32_t nid = static_cast(nodes.size()); + nodes.push_back(std::move(node)); + // create a variable node for gradient input DataEntry igrad(nid, 0); head_grad_nodes->push_back(nid); // update gradient map @@ -204,31 +217,25 @@ void StaticGraph::MakeBackwardPass(std::vector *head_grad_nodes, if (gnodes.size() == 1) { out_grad.push_back(gnodes[0]); } else { - // find multiple gradients, need aggregate - std::ostringstream os_size, os_name; - uint32_t agg_node_id = static_cast(nodes.size()); - nodes.push_back(Node()); - Node &agg_node = nodes[agg_node_id]; - agg_node.op.reset(OperatorProperty::Create("ElementWiseSum")); - os_size << gnodes.size(); - agg_node.op->Init({{"size", os_size.str()}}); + std::ostringstream os_name; + Node agg_node = StaticGraph::CreateSumNode(gnodes); os_name << nodes[nid].name << '_' << i << "_out_grad_agg"; agg_node.name = os_name.str(); - agg_node.inputs = gnodes; + uint32_t agg_node_id = static_cast(nodes.size()); + nodes.push_back(std::move(agg_node)); out_grad.push_back(DataEntry(agg_node_id, 0)); } } // Create a gradient backward node - nodes.push_back(Node()); - uint32_t grad_node_id = static_cast(nodes.size()); - Node &grad_node = nodes[grad_node_id]; + Node grad_node; // Point to the corresponding source grad_node.backward_source_id = nid; // select out the dependent inputs grad_node.inputs = nodes[nid].op->BackwardInputs( out_grad, nodes[nid].inputs, out_data); grad_node.name = nodes[nid].name + "_backward"; - + uint32_t grad_node_id = static_cast(nodes.size()); + nodes.push_back(std::move(grad_node)); // update gradient map for (size_t i = 0; i < nodes[nid].inputs.size(); ++i) { DataEntry idata = nodes[nid].inputs[i]; diff --git a/windows/mxnet.sln b/windows/mxnet.sln deleted file mode 100755 index 16f82f6b6fb1..000000000000 --- a/windows/mxnet.sln +++ /dev/null @@ -1,28 +0,0 @@ - -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 2013 -VisualStudioVersion = 12.0.21005.1 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "mxnet", "mxnet.vcxproj", "{2DA41CBC-B8B2-4696-86CD-9AFBAB029661}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|Win32 = Debug|Win32 - Debug|x64 = Debug|x64 - Release|Win32 = Release|Win32 - Release|x64 = Release|x64 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Debug|Win32.ActiveCfg = Debug|Win32 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Debug|Win32.Build.0 = Debug|Win32 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Debug|x64.ActiveCfg = Debug|x64 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Debug|x64.Build.0 = Debug|x64 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Release|Win32.ActiveCfg = Release|Win32 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Release|Win32.Build.0 = Release|Win32 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Release|x64.ActiveCfg = Release|x64 - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661}.Release|x64.Build.0 = Release|x64 - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection -EndGlobal diff --git a/windows/mxnet.vcxproj b/windows/mxnet.vcxproj deleted file mode 100755 index 2823478cc51f..000000000000 --- a/windows/mxnet.vcxproj +++ /dev/null @@ -1,148 +0,0 @@ - - - - - Debug - Win32 - - - Debug - x64 - - - Release - Win32 - - - Release - x64 - - - - {2DA41CBC-B8B2-4696-86CD-9AFBAB029661} - Win32Proj - - - - Application - true - v120 - - - Application - true - v120 - - - Application - false - v120 - - - Application - false - v120 - - - - - - - - - - - - - - - - - - - true - - - true - - - true - - - true - - - - WIN32;_DEBUG;_WINDOWS;%(PreprocessorDefinitions) - MultiThreadedDebugDLL - Level3 - ProgramDatabase - Disabled - - - MachineX86 - true - Windows - - - - - WIN32;_DEBUG;_WINDOWS;%(PreprocessorDefinitions) - MultiThreadedDebugDLL - Level3 - ProgramDatabase - Disabled - $(solutionDir)\..\src - - - true - Console - - - - - WIN32;NDEBUG;_WINDOWS;%(PreprocessorDefinitions) - MultiThreadedDLL - Level3 - ProgramDatabase - - - MachineX86 - true - Windows - true - true - - - - - WIN32;NDEBUG;_WINDOWS;%(PreprocessorDefinitions) - MultiThreadedDLL - Level3 - ProgramDatabase - $(solutionDir)\..\src - - - true - Console - true - true - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/windows/mxnet.vcxproj.filters b/windows/mxnet.vcxproj.filters deleted file mode 100755 index 1ff068b088be..000000000000 --- a/windows/mxnet.vcxproj.filters +++ /dev/null @@ -1,48 +0,0 @@ - - - - - {4FC737F1-C7A5-4376-A066-2A32D752A2FF} - cpp;c;cc;cxx;def;odl;idl;hpj;bat;asm;asmx - - - {93995380-89BD-4b04-88EB-625FBE52EBFB} - h;hh;hpp;hxx;hm;inl;inc;xsd - - - {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} - rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav - - - - - Source Files - - - Source Files - - - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - Header Files - - - \ No newline at end of file diff --git a/windows/mxnet.vcxproj.user b/windows/mxnet.vcxproj.user deleted file mode 100755 index ef5ff2a1fae6..000000000000 --- a/windows/mxnet.vcxproj.user +++ /dev/null @@ -1,4 +0,0 @@ - - - - \ No newline at end of file