diff --git a/Makefile b/Makefile index 74944012df8e..b763a406da23 100644 --- a/Makefile +++ b/Makefile @@ -58,14 +58,14 @@ endif BIN = test/api_registry_test test/test_storage OBJ = narray_op_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o +OBJCXX11 = engine.o narray.o c_api.o registry.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o elementwise_sum_cpu.o CUOBJ = SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a ifeq ($(USE_CUDA), 1) - CUOBJ += narray_op_gpu.o fully_connected_gpu.o + CUOBJ += narray_op_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o endif .PHONY: clean all test lint doc @@ -87,7 +87,10 @@ c_api.o: src/c_api.cc operator.o: src/operator/static_operator_wrapper.cc fully_connected_cpu.o: src/operator/fully_connected.cc fully_connected_gpu.o: src/operator/fully_connected.cu - +activation_cpu.o: src/operator/activation.cc +activation_gpu.o: src/operator/activation.cu +elementwise_sum_cpu.o: src/operator/elementwise_sum.cc +elementwise_sum_gpu.o: src/operator/elementwise_sum.cu lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index a9a15c4a8007..c6cddce00df8 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -49,10 +49,9 @@ typedef void *DataIterHandle; * \return error info */ MXNET_DLL const char *MXGetLastError(); - -//-------------------------------- +//------------------------------------- // Part 1: NArray creation and deletion -//-------------------------------- +//------------------------------------- /*! * \brief create a NArray handle that is not initialized * can be used to pass in as mutate variables @@ -189,7 +188,6 @@ MXNET_DLL int MXFuncDescribe(FunctionHandle fun, mx_uint *num_scalars, mx_uint *num_mutate_vars, int *type_mask); - /*! * \brief invoke a function, the array size of passed in arguments * must match the values in the @@ -301,8 +299,8 @@ MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol, - mx_uint *out_size, - const char ***out_str_array); + mx_uint *out_size, + const char ***out_str_array); /*! * \brief Compose the symbol on other symbols. * @@ -322,6 +320,37 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym, mx_uint num_args, const char** keys, SymbolHandle* args); +/*! + * \brief infer shape of unknown input shapes given the known one. + * The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data + * The call will be treated as a kwargs call if key != nullptr or num_args==0, otherwise it is positional. + * + * \param sym symbol handle + * \param num_args numbe of input arguments. + * \param keys the key of keyword args (optional) + * \param arg_ind_ptr the head pointer of the rows in CSR + * \param arg_shape_data the content of the CSR + * \param in_shape_size sizeof the returning array of in_shapes + * \param in_shape_ndim returning array of shape dimensions of eachs input shape. + * \param in_shape_data returning array of pointers to head of the input shape. + * \param out_shape_size sizeof the returning array of out_shapes + * \param out_shape_ndim returning array of shape dimensions of eachs input shape. + * \param out_shape_data returning array of pointers to head of the input shape. 
+ * \param complete whether infer shape completes or more information is needed. + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const mx_uint *arg_shape_data, + mx_uint *in_shape_size, + const mx_uint **in_shape_ndim, + const mx_uint ***in_shape_data, + mx_uint *out_shape_size, + const mx_uint **out_shape_ndim, + const mx_uint ***out_shape_data, + int *complete); //-------------------------------------------- // Part 4: operator interface on NArray //-------------------------------------------- @@ -352,24 +381,6 @@ MXNET_DLL int MXOpFree(OperatorHandle op); */ MXNET_DLL int MXOpDescribeArgs(mx_uint *out_size, int **out_array); -/*! - * \brief infer shape of unknown input shapes given the known one - * this function do not return the shape of output - * the shapes are packed into a CSR matrix represened by ind_ptr and shape_array - * - * When the function returns, it return a new CSR matrix by updating ind_ptr, - * and return the content in the return value - * - * \param ind_ptr the head pointer of the rows in CSR - * \param shape_array the content of the CSR - * \param out_nout number of output arguments of this operation - * \param out_array another content of CSR with infered shape - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXOpInferShape(mx_uint *ind_ptr, - mx_uint *shape_array, - mx_uint *out_nout, - mx_uint *out_array); /*! * \brief call forward on the operator * \param op the operator handle diff --git a/include/mxnet/context.h b/include/mxnet/context.h index 262ba2e787d4..8dfa618ca180 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -6,6 +6,8 @@ #ifndef MXNET_CONTEXT_H_ #define MXNET_CONTEXT_H_ +#include "./base.h" + namespace mxnet { /*! \brief Context information about the execution enviroment */ diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 938083dbab33..0fa1fb6a0571 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -40,16 +40,18 @@ enum OpReqType { struct OpContext { /*! \brief whether it is training phase */ int is_train; - /*! \brief Stream we are running on */ - void *stream; + /*! \brief RunContext related resources */ + RunContext run_ctx; /*! \brief Resources requested by the operator */ std::vector requested; /*! - * \brief set the RunContext related parts - * \param ctx the context + * \brief get mshadow stream from Context + * \return the mshadow stream + * \tparam xpu the device type of the stream */ - inline void SetRunContext(const RunContext &ctx) { - stream = ctx.stream; + template + inline mshadow::Stream* get_stream() const { + return static_cast*>(run_ctx.stream); } }; @@ -84,13 +86,22 @@ class Operator { const std::vector &out_data) = 0; /*! * \brief Perform a Backward Operation, write gradient to the in_grad. + * + * Convention: + * out_grad.size() == OperatorProperty.NumVisibleReturns() + * out_data.size() == OperatorProperty.NumReturns() + * out_data can contain additional invisible returns that remembers the + * state carried from the Forward pass. For example mask in the dropout. + * + * The gradients are passed from visible returns in this function. + * * \param ctx runtime context available to this call - * \param out_grad the gradient value we get from output of the Operator + * \param out_grad the gradient value we get from of the Operator. * \param in_data the array of input data. 
* \param out_data the array of output data. * \param req request types of the saving operation, can be all types. * \param in_grad the array of gradient we need to write to. - * \sa OpReqType, OpContext + * \sa OpReqType, OpContext, OperatorProperty */ virtual void Backward(const OpContext &ctx, const std::vector &out_grad, @@ -115,6 +126,12 @@ class OperatorProperty { * \brief virtual destructor */ virtual ~OperatorProperty() {} + /*! + * \brief Initialize the Operator by setting the parameters + * This function need to be called before all other functions. + * \param kwargs the keyword arguments parameters + */ + virtual void Init(const std::vector >& kwargs) = 0; /*! * \brief Get input arguments of the Operator. * \return vector of arguments. @@ -148,12 +165,6 @@ class OperatorProperty { virtual int NumVisibleReturns() const { return NumReturns(); } - /*! - * \brief Set the parameters of the Operator. - * \param name parameter name - * \param val string for the configuration - */ - virtual void SetParam(const char *name, const char *val) {} /*! * \brief infer the shapes of outputs and unknown input arguments * \param in_shape the shape of input arguments of the operator @@ -166,7 +177,8 @@ class OperatorProperty { * * \param out_shape the shape of outputs of the operator * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. + * \return true if the shape inference is successful, false if there is not enough information. + * \throws dmlc::Error if the known arg_shapes are inconsistent. */ virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const = 0; diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h index dc00f5a33fb6..e24c03a0cd0b 100644 --- a/include/mxnet/symbolic.h +++ b/include/mxnet/symbolic.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include "./base.h" @@ -37,8 +38,45 @@ class StaticGraph { uint32_t source_id; /*! \brief index of output from the source. */ uint32_t index; + /*! \brief default constructor */ + DataEntry() {} + /*! + * \brief constructor with source and index + * \param source_id source id + * \param index node index + */ + DataEntry(uint32_t source_id, uint32_t index) + : source_id(source_id), index(index) {} + /*! + * \brief compare equality + * \param other the other entry to compare + * \return whether two entries equals to each other + */ + inline bool operator==(const DataEntry &other) const { + return source_id == other.source_id && index == other.index; + } + /*! + * \brief comparator, allows to use map + * \param other the other entry to compare + * \return whether two entries is smaller than the other + */ + inline bool operator<(const DataEntry &other) const { + if (source_id == other.source_id) return index < other.index; + return source_id < other.source_id; + } }; - /*! \brief Operation Node in static graph */ + /*! + * \brief Operation Node in static graphs. + * There are two types of node, Forward and Backward Node. + * + * - Forward node corresponds to the op.Forward + * - Backward node corresponds to the Backward pass, + * where the corresponding forward node is indicated by backward_source_id. + * The op field in Backward node is nullptr + * + * The reason we explicit support Backward node is to allow special treatment + * such as shape inference and state sharing with Forward pass. + */ struct Node { /*! 
\brief wrapped operator property */ std::unique_ptr<OperatorProperty> op; @@ -46,13 +84,36 @@ class StaticGraph { std::string name; /*! \brief inputs (node_id, index) for of the nodes*/ std::vector<DataEntry> inputs; + /*! + * \brief If this field is nonnegative, it indicates that this + * Node corresponds to a Backward Operation of an Operator. + * backward_source_id will point to the corresponding Forward Node. + * + * For a normal node, this field is -1. + * When the node is a Backward node, the op field will be nullptr. + */ + int32_t backward_source_id; + /*! \brief default constructor */ + Node() : backward_source_id(-1) {} + /*! \return whether the node is a forward op node */ + inline bool is_forward() const { + return op != nullptr; + } + /*! \return whether the node is a backward op node */ + inline bool is_backward() const { + return backward_source_id != -1; + } + /*! \return whether the node is a variable node */ + inline bool is_variable() const { + return op == nullptr && !is_backward(); + } }; /*! \brief all nodes in the graph */ std::vector<Node> nodes; - /*! \brief index is nodes that correspods to arguments */ + /*! \brief indices of nodes that correspond to arguments */ std::vector<uint32_t> arg_nodes; - /*! \brief outputs(heads) of the graph */ - std::vector<DataEntry> outputs; + /*! \brief head outputs of the graph */ + std::vector<DataEntry> heads; // funtions to help inference in static graph /*! * \brief Perform a topological sort on the graph @@ -85,8 +146,22 @@ class StaticGraph { * InferShape will modify the vector to fill output TShape * \return if the shape inference is successful, return true, else return false. */ - bool InferShape(std::vector<TShape> *in_shape, - std::vector<TShape> *out_shape) const; + bool InferShape(std::vector<TShape>* in_shape, + std::vector<TShape>* out_shape) const; + /*! + * \brief Add a full backward pass in the static graph. + * This function will add gradient nodes for each head, + * and add the backward pass to backprop the gradients all + * the way to the arguments. + * + * This will change the nodes field in the StaticGraph, but will not change other fields. + * The heads and inputs of the Backward pass will be returned by head_grad_nodes and arg_grads. + * + * \param head_grad_nodes used to store the created head gradient inputs for the backward pass. + * \param arg_grads used to store gradients to args, can have multiple entries if an argument is used by more than one operator + */ + void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes, + std::vector<std::vector<DataEntry> > *arg_grads); }; /*! @@ -174,7 +250,7 @@ class Symbol { const std::string& name) const; /*! * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator + * \param arg_shapes the shape of input arguments of the operator * this should be of same length as the vector returned by ListArguments * in_shape allows unknown elements, which are checked by shape.ndim() == 0. * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape * * common practice: set the shape of data input, and usually weight's shape can be infered * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. + * \param out_shapes used to store the inferred shapes of outputs. + * \return true if the shape inference is successful, false if there is not enough information. + * \throws dmlc::Error if the known arg_shapes are inconsistent.
+ */ + bool InferShape(std::vector *arg_shapes, + std::vector *out_shapes) const; + /*! + * \brief infer the shapes by providing shapes of known arguments. + * \param known_arg_shapes map of argument name to shape of arguments with known shapes. + * \param arg_shapes used to store infered shapes of arguments. + * \param out_shapes used to store infered shapes of outputs. + * \return true if the shape inference is successful, false if there is not enough information. + * \throws dmlc::Error if the known arg_shapes are inconsistent. */ - bool InferShape(std::vector *in_shape, std::vector *out_shape) const; + bool InferShape(const std::unordered_map &known_arg_shapes, + std::vector *arg_shapes, + std::vector *out_shapes) const; /*! * \brief get number of outputs of this symbol * \return number of outputs diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py index 26a2198bd765..61839ecc0a60 100644 --- a/python/mxnet/narray.py +++ b/python/mxnet/narray.py @@ -134,7 +134,7 @@ def shape(self): pdata = ctypes.POINTER(mx_uint)() check_call(_LIB.MXNArrayGetShape( self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) - return tuple(pdata[i] for i in range(ndim.value)) + return tuple(pdata[:ndim.value]) @property def context(self): diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 031b18ab862f..d29bd7285ba9 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1,11 +1,11 @@ # coding: utf-8 -# pylint: disable=invalid-name, protected-access +# pylint: disable=invalid-name, protected-access, too-many-locals """Symbol support of mxnet""" from __future__ import absolute_import import ctypes from .base import _LIB -from .base import c_array, c_str +from .base import c_array, c_str, mx_uint from .base import SymbolHandle from .base import check_call @@ -136,6 +136,80 @@ def list_returns(self): self.handle, ctypes.byref(size), ctypes.byref(sarr))) return [sarr[i] for i in range(size.value)] + def infer_shape(self, *args, **kwargs): + """Infer the shape of outputs and arguments of given known shapes of arguments. + + User can either pass in the known shapes in positional way or keyword argument way. + Pair of Nones is returned if there is not enough information passed in. + An error will be raised if there is inconsistency found in the known shapes passed in. + + Parameters + ---------- + *args : + Provide shape of arguments in a positional way. + Unknown shape can be marked as None + + **kwargs : + Provide keyword arguments of known shapes. + + Returns + ------- + arg_shapes : list of tuple or None + List of shapes of arguments. + The order is in the same order as list_arguments() + out_shapes : list of tuple or None + List of shapes of outputs. 
+ The order is in the same order as list_returns() + """ + if len(args) != 0 and len(kwargs) != 0: + raise ValueError('Can only specify known argument \ + shapes either by positional or kwargs way.') + sdata = [] + indptr = [0] + if len(args) != 0: + keys = None + for s in args: + if s is not None: + if not isinstance(s, tuple): + raise TypeError('Argument need to be shapes(tuple)') + sdata.extend(s) + indptr.append(len(sdata)) + else: + keys = [] + for k, v in kwargs.items(): + keys.append(c_str(k)) + if not isinstance(v, tuple): + raise TypeError('Argument need to be shapes(tuple)') + sdata.extend(v) + indptr.append(len(sdata)) + arg_shape_size = mx_uint() + arg_shape_ndim = ctypes.POINTER(mx_uint)() + arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + out_shape_size = mx_uint() + out_shape_ndim = ctypes.POINTER(mx_uint)() + out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_uint))() + complete = ctypes.c_int() + check_call(_LIB.MXSymbolInferShape( \ + self.handle, len(indptr) - 1, \ + c_array(ctypes.c_char_p, keys), \ + c_array(mx_uint, indptr), \ + c_array(mx_uint, sdata), \ + ctypes.byref(arg_shape_size), \ + ctypes.byref(arg_shape_ndim), \ + ctypes.byref(arg_shape_data), \ + ctypes.byref(out_shape_size), \ + ctypes.byref(out_shape_ndim), \ + ctypes.byref(out_shape_data), \ + ctypes.byref(complete))) + if complete.value != 0: + arg_shapes = [tuple(arg_shape_data[i][:arg_shape_ndim[i]]) \ + for i in range(arg_shape_size.value)] + out_shapes = [tuple(out_shape_data[i][:out_shape_ndim[i]]) \ + for i in range(out_shape_size.value)] + return (arg_shapes, out_shapes) + else: + return (None, None) + def debug_str(self): """Get a debug string. diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py index c81deebaef11..d507a9c2871a 100644 --- a/python/mxnet/symbol_creator.py +++ b/python/mxnet/symbol_creator.py @@ -54,7 +54,7 @@ def __call__(self, *args, **kwargs): if isinstance(v, Symbol): symbol_kwargs[k] = v else: - param_keys.append(k) + param_keys.append(c_str(k)) param_vals.append(c_str(str(v))) # create atomic symbol diff --git a/python/test_infer_shape.py b/python/test_infer_shape.py new file mode 100644 index 000000000000..b94388e5546d --- /dev/null +++ b/python/test_infer_shape.py @@ -0,0 +1,19 @@ +# pylint: skip-file +import mxnet as mx + +data = mx.sym.Variable('data') + +fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=1000) +fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', num_hidden=10) +fc3 = mx.sym.FullyConnected( name='fc2', num_hidden=10) + +print fc2.list_arguments() + +data_shape = (100, 100) +arg_shapes, out_shapes = fc2.infer_shape(data=data_shape) +print dict(zip(fc2.list_arguments(), arg_shapes)) +print dict(zip(fc2.list_returns(), out_shapes)) + +weight_shape= (1, 100) +data_shape = (100, 100) +arg_shapes, out_shapes = fc2.infer_shape(data=data_shape, fc1_weight=weight_shape) diff --git a/src/c_api.cc b/src/c_api.cc index d5a1a67d70c6..ed5446fc816a 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -27,61 +28,76 @@ #message("Warning: Threadlocal is not enabled"); #endif -/*! \brief symbol wrapper to easily hold returning information */ -struct MXAPISymbolWrapper { - /*! \brief the actual symbol */ - mxnet::Symbol sym; +using namespace mxnet; + +/*! \brief entry to to easily hold returning information */ +struct MXAPIThreadLocalEntry { + /*! \brief holds last error message */ + std::string last_error; /*! 
\brief result holder for returning string */ std::string ret_str; /*! \brief result holder for returning strings */ std::vector ret_vec_str; /*! \brief result holder for returning string pointers */ std::vector ret_vec_charp; + /*! \brief result holder for returning shapes */ + std::vector arg_shapes, out_shapes; + /*! \brief result holder for returning shape dimensions */ + std::vector arg_shape_ndim, out_shape_ndim; + /*! \brief result holder for returning shape pointer */ + std::vector arg_shape_data, out_shape_data; + // helper function to setup return value of shape array + inline static void SetupShapeArrayReturn( + const std::vector &shapes, + std::vector *ndim, + std::vector *data) { + ndim->resize(shapes.size()); + data->resize(shapes.size()); + for (size_t i = 0; i < shapes.size(); ++i) { + ndim->at(i) = shapes[i].ndim(); + data->at(i) = shapes[i].data(); + } + } }; /*! - * \brief helper to store error message in threadlocal storage + * \brief A threadlocal store to store threadlocal variables. + * Will return a thread local singleton of type T + * \tparam T the type we like to store */ -class MXAPIErrorMessageHelper { +class MXAPIThreadLocalStore { public: - /*! \brief get a single instance out from */ - static MXAPIErrorMessageHelper *Get() { - static MXAPIErrorMessageHelper inst; - return &inst; - } - /*! - * \brief a helper function for error handling - * will set the last error to be str_set when it is not NULL - * \param str_set the error to set - * \return a pointer message to last error - */ - static const char *SetGetLastError(const char *str_set) { - // use last_error to record last error - static MX_TREAD_LOCAL std::string *last_error = NULL; - if (last_error == NULL) { - last_error = new std::string(); - Get()->RegisterDelete(last_error); + /*! \brief store return entry */ + typedef MXAPIThreadLocalEntry T; + /*! \return get a thread local singleton */ + static T* Get() { + static MX_TREAD_LOCAL T* ptr = nullptr; + if (ptr == nullptr) { + ptr = new T(); + Singleton()->RegisterDelete(ptr); } - if (str_set != NULL) { - *last_error = str_set; - } - return last_error->c_str(); + return ptr; } private: /*! \brief constructor */ - MXAPIErrorMessageHelper() {} + MXAPIThreadLocalStore() {} /*! \brief destructor */ - ~MXAPIErrorMessageHelper() { + ~MXAPIThreadLocalStore() { for (size_t i = 0; i < data_.size(); ++i) { delete data_[i]; } } + /*! \return singleton of the store */ + static MXAPIThreadLocalStore *Singleton() { + static MXAPIThreadLocalStore inst; + return &inst; + } /*! * \brief register str for internal deletion * \param str the string pointer */ - void RegisterDelete(std::string *str) { + void RegisterDelete(T *str) { std::unique_lock lock(mutex_); data_.push_back(str); lock.unlock(); @@ -89,13 +105,12 @@ class MXAPIErrorMessageHelper { /*! \brief internal mutex */ std::mutex mutex_; /*!\brief internal data */ - std::vector data_; + std::vector data_; }; // NOTE: all functions return 0 upon success // consider add try/catch block for user error // handling in the future -using namespace mxnet; /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { @@ -111,7 +126,7 @@ using namespace mxnet; /*! \brief return str message of the last error */ const char *MXGetLastError() { - return MXAPIErrorMessageHelper::SetGetLastError(NULL); + return MXAPIThreadLocalStore::Get()->last_error.c_str(); } /*! 
@@ -120,7 +135,7 @@ const char *MXGetLastError() { * \return the return value of API after exception is handled */ int MXHandleException(const dmlc::Error &e) { - MXAPIErrorMessageHelper::SetGetLastError(e.what()); + MXAPIThreadLocalStore::Get()->last_error = e.what(); return -1; } @@ -295,24 +310,26 @@ int MXSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, const char **keys, const char **vals, SymbolHandle *out) { - MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + Symbol *s = new Symbol(); OperatorProperty *op = nullptr; API_BEGIN(); OperatorPropertyEntry *e = static_cast(creator); op = (*e)(); + std::vector > kwargs; for (int i = 0; i < num_param; ++i) { - op->SetParam(keys[i], vals[i]); + kwargs.push_back({std::string(keys[i]), std::string(vals[i])}); } - s->sym = Symbol::Create(op); + op->Init(kwargs); + *s = Symbol::Create(op); *out = s; API_END_HANDLE_ERROR(delete s; delete op); } int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { - MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + Symbol *s = new Symbol(); API_BEGIN(); - s->sym = Symbol::CreateVariable(name); + *s = Symbol::CreateVariable(name); *out = s; API_END_HANDLE_ERROR(delete s); } @@ -320,71 +337,72 @@ int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { int MXSymbolCreateGroup(mx_uint num_symbols, SymbolHandle *symbols, SymbolHandle *out) { - MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); - MXAPISymbolWrapper **sym_arr = (MXAPISymbolWrapper**)symbols; // NOLINT(*) + Symbol *s = new Symbol(); + Symbol **sym_arr = (Symbol**)symbols; // NOLINT(*) API_BEGIN(); std::vector syms; for (mx_uint i = 0; i < num_symbols; ++i) { - syms.push_back(sym_arr[i]->sym); + syms.push_back(*sym_arr[i]); } - s->sym = Symbol::CreateGroup(syms); + *s = Symbol::CreateGroup(syms); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); - delete static_cast(symbol); + delete static_cast(symbol); API_END(); } int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { - MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); - + Symbol *s = new Symbol(); API_BEGIN(); - s->sym = (static_cast(symbol)->sym).Copy(); + *s = static_cast(symbol)->Copy(); *out = s; API_END_HANDLE_ERROR(delete s); } int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { - MXAPISymbolWrapper *s = static_cast(symbol); - + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); std::ostringstream os; - (s->sym).Print(os); - s->ret_str = os.str(); - *out_str = (s->ret_str).c_str(); + s->Print(os); + ret->ret_str = os.str(); + *out_str = (ret->ret_str).c_str(); API_END(); } int MXSymbolListArguments(SymbolHandle symbol, mx_uint *out_size, const char ***out_str_array) { - MXAPISymbolWrapper *s = static_cast(symbol); + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); - s->ret_vec_str = std::move((s->sym).ListArguments()); - s->ret_vec_charp.clear(); - for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { - s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + ret->ret_vec_str = std::move(s->ListArguments()); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } - *out_size = static_cast(s->ret_vec_charp.size()); - *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } int 
MXSymbolListReturns(SymbolHandle symbol, - mx_uint *out_size, - const char ***out_str_array) { - MXAPISymbolWrapper *s = static_cast(symbol); + mx_uint *out_size, + const char ***out_str_array) { + Symbol *s = static_cast(symbol); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); - s->ret_vec_str = std::move((s->sym).ListReturns()); - s->ret_vec_charp.clear(); - for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { - s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + ret->ret_vec_str = std::move(s->ListReturns()); + ret->ret_vec_charp.clear(); + for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) { + ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str()); } - *out_size = static_cast(s->ret_vec_charp.size()); - *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); + *out_size = static_cast(ret->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(ret->ret_vec_charp); API_END(); } @@ -397,19 +415,68 @@ int MXSymbolCompose(SymbolHandle sym, std::string s_name; if (name != nullptr) s_name = name; - MXAPISymbolWrapper* s = static_cast(sym); + Symbol* s = static_cast(sym); if (keys == nullptr && num_args != 0) { std::vector pos_args; for (mx_uint i = 0; i < num_args; ++i) { - pos_args.push_back(((MXAPISymbolWrapper*)(args[i]))->sym); // NOLINT(*) + pos_args.push_back(*((Symbol*)args[i])); // NOLINT(*) } - (s->sym).Compose(pos_args, s_name); + s->Compose(pos_args, s_name); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = ((MXAPISymbolWrapper*)(args[i]))->sym; // NOLINT(*) + kwargs[keys[i]] = *((Symbol*)args[i]); // NOLINT(*) + } + s->Compose(kwargs, s_name); + } + API_END(); +} + +int MXSymbolInferShape(SymbolHandle sym, + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const mx_uint *arg_shape_data, + mx_uint *in_shape_size, + const mx_uint **in_shape_ndim, + const mx_uint ***in_shape_data, + mx_uint *out_shape_size, + const mx_uint **out_shape_ndim, + const mx_uint ***out_shape_data, + int *complete) { + Symbol *s = static_cast(sym); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + bool succ; + API_BEGIN(); + if (keys == nullptr && num_args != 0) { + ret->arg_shapes.clear(); + for (mx_uint i = 0; i < num_args; ++i) { + ret->arg_shapes.push_back(TShape(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i+1])); } - (s->sym).Compose(kwargs, s_name); + succ = s->InferShape(&(ret->arg_shapes), &(ret->out_shapes)); + } else { + std::unordered_map kwargs; + for (mx_uint i = 0; i < num_args; ++i) { + kwargs[keys[i]] = TShape(arg_shape_data + arg_ind_ptr[i], + arg_shape_data + arg_ind_ptr[i+1]); + } + succ = s->InferShape(kwargs, &(ret->arg_shapes), &(ret->out_shapes)); + } + if (succ) { + MXAPIThreadLocalEntry::SetupShapeArrayReturn( + ret->arg_shapes, &(ret->arg_shape_ndim), &(ret->arg_shape_data)); + MXAPIThreadLocalEntry::SetupShapeArrayReturn( + ret->out_shapes, &(ret->out_shape_ndim), &(ret->out_shape_data)); + *in_shape_size = static_cast(ret->arg_shapes.size()); + *in_shape_ndim = dmlc::BeginPtr(ret->arg_shape_ndim); + *in_shape_data = dmlc::BeginPtr(ret->arg_shape_data); + *out_shape_size = static_cast(ret->out_shapes.size()); + *out_shape_ndim = dmlc::BeginPtr(ret->out_shape_ndim); + *out_shape_data = dmlc::BeginPtr(ret->out_shape_data); + *complete = 1; + } else { + *complete = 0; } API_END(); } diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h new file mode 100644 index 000000000000..fd643a6405da --- /dev/null +++ 
b/src/operator/activation-inl.h @@ -0,0 +1,142 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file activation-inl.h + * \brief Activation operator + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_ACTIVATION_INL_H_ +#define MXNET_OPERATOR_ACTIVATION_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { +// Declare enumeration of input order to make code more intuitive. +// // These enums are only visible within this header +enum ActivationOpInputs {kData}; +enum ActivationOpOutputs {kOut}; +enum ActivationOpType {kReLU, kSigmoid, kTanh}; + +struct ActivationParam : public dmlc::Parameter { + // use int for enumeration + int type; + DMLC_DECLARE_PARAMETER(ActivationParam) { + DMLC_DECLARE_FIELD(type).set_default(kReLU).add_enum("relu", kReLU).\ + add_enum("sigmoid", kSigmoid).add_enum("tanh", kTanh); + } +}; + +/** + * \brief This is the implementation of activation operator. + * \tparam xpu The device that the op will be executed on. + */ +template +class ActivationOp : public Operator { + public: + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[kData].FlatTo2D(s); + Tensor out = out_data[kOut].FlatTo2D(s); + Assign(out, req[kOut], F(data)); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK(in_data.size() == 1 && in_grad.size() == 1); + CHECK_EQ(req.size(), 1); + Stream *s = ctx.get_stream(); + Tensor m_out_grad = out_grad[kOut].FlatTo2D(s); + Tensor m_out_data = out_data[kOut].FlatTo2D(s); + Tensor m_in_grad = in_grad[kData].FlatTo2D(s); + Assign(m_in_grad, req[kData], F(m_out_data) * m_out_grad); + } +}; // class ActivationOp + +// Decalre Factory function, used for dispatch specialization +template +Operator* CreateOp(ActivationParam type); + +#if DMLC_USE_CXX11 +class ActivationProp : public OperatorProperty { + public: + virtual void Init(const std::vector >& kwargs) { + // TODO(bing) change directly to vector of pairs begin end + std::map kmap(kwargs.begin(), kwargs.end()); + param_.Init(kmap); + } + + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + out_shape->clear(); + out_shape->push_back(dshape); + return true; + } + + virtual OperatorProperty* Copy() const { + auto ptr = new ActivationProp(); + ptr->param_ = param_; + return ptr; + } + + virtual std::string TypeString() const { + return "Activation"; + } + + // decalre dependency and inplace optimization options + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + return {out_grad[kOut], out_data[kOut]}; + } + + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return {{out_grad[kOut], in_grad[kData]}}; + 
} + + virtual std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const { + return {{in_data[kData], out_data[kOut]}}; + } + + Operator* CreateOperator(Context ctx) const; + + private: + ActivationParam param_; +}; +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_ACTIVATION_INL_H_ + diff --git a/src/operator/activation.cc b/src/operator/activation.cc new file mode 100644 index 000000000000..275588e099af --- /dev/null +++ b/src/operator/activation.cc @@ -0,0 +1,36 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file activation.cc + * \brief activation op + * \author Bing Xu +*/ + +#include +#include "./activation-inl.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(ActivationParam param) { + switch (param.type) { + case kReLU: return new ActivationOp(); + case kSigmoid: return new ActivationOp(); + case kTanh: return new ActivationOp(); + default: + LOG(FATAL) << "unknown activation type"; + return NULL; + } +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator *ActivationProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(ActivationParam); + +REGISTER_OP_PROPERTY(Activation, ActivationProp); +} // namespace op +} // namespace mxnet + diff --git a/src/operator/activation.cu b/src/operator/activation.cu new file mode 100644 index 000000000000..5b7b576e59d7 --- /dev/null +++ b/src/operator/activation.cu @@ -0,0 +1,25 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file activation.cu + * \brief + * \author Bing Xu +*/ +#include "./activation-inl.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(ActivationParam param) { + switch(param.type) { + case kReLU: return new ActivationOp(); + case kSigmoid: return new ActivationOp(); + case kTanh: return new ActivationOp(); + default: + LOG(FATAL) << "unknown activation"; + return NULL; + } +} +} // op +} // namespace mxnet + diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h new file mode 100644 index 000000000000..65a6ba1d5c99 --- /dev/null +++ b/src/operator/elementwise_sum-inl.h @@ -0,0 +1,173 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file elemementwise_sum-inl.h + * \brief elementwise sum + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ +#define MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +enum ElementWiseSumOpInputs {kData0, kData1, kData2, kData3}; +enum ElementWiseSumOpOutputs {kOut}; + +struct ElementWiseSumParam : public dmlc::Parameter { + int size; + DMLC_DECLARE_PARAMETER(ElementWiseSumParam) { + DMLC_DECLARE_FIELD(size).set_range(1, 100); + } +}; + +template +class ElementWiseSumOp : public Operator { + public: + explicit ElementWiseSumOp(ElementWiseSumParam param) + : size_(param.size) {} + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(static_cast(in_data.size()), size_); + CHECK_EQ(out_data.size(), 1); + if (req[kOut] == kNullOp) return; + + Stream *s = ctx.get_stream(); + Tensor out = out_data[kOut].FlatTo2D(s); + switch (size_) { + case 2: { + Tensor in_0 = in_data[kData0].FlatTo2D(s); + Tensor in_1 = in_data[kData1].FlatTo2D(s); + Assign(out, req[kOut], in_0 + in_1); + break; + } + case 3: { + Tensor in_0 = in_data[kData0].FlatTo2D(s); + Tensor in_1 = in_data[kData1].FlatTo2D(s); + Tensor in_2 = in_data[kData2].FlatTo2D(s); + Assign(out, req[kOut], in_0 + in_1 + in_2); + break; + } + case 4: { + Tensor in_0 = in_data[kData0].FlatTo2D(s); + Tensor in_1 = in_data[kData1].FlatTo2D(s); + Tensor in_2 = in_data[kData2].FlatTo2D(s); + Tensor in_3 = in_data[kData3].FlatTo2D(s); + Assign(out, req[kOut], in_0 + in_1 + in_2 + in_3); + break; + } + default: { + Tensor in_0 = in_data[kData0].FlatTo2D(s); + Assign(out, req[kOut], in_0); + for (int i = 0; i < size_; ++i) { + out += in_data[i].FlatTo2D(s); + } + } + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), static_cast(size_)); + Stream *s = ctx.get_stream(); + Tensor ograd = out_grad[kOut].FlatTo2D(s); + + for (int i = 0; i < size_; ++i) { + if (req[i] == kNullOp || req[i] == kWriteInplace) continue; + Tensor igrad = in_grad[i].FlatTo2D(s); + Assign(igrad, req[i], ograd); + } + } + + private: + int size_; +}; // class ElementWiseSumOp + +template +Operator* CreateOp(ElementWiseSumParam param); + +#if DMLC_USE_CXX11 +class ElementWiseSumProp : public OperatorProperty { + public: + virtual void Init(const std::vector >& kwargs) { + // TODO(bing) change directly to vector of pairs begin end + std::map kmap(kwargs.begin(), kwargs.end()); + param_.Init(kmap); + } + + virtual bool InferShape(std::vector *in_shape, + std::vector *out_shape) const { + using namespace mshadow; + CHECK_EQ(in_shape->size(), static_cast(param_.size)); + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + for (int i = 1; i < param_.size; ++i) { + SHAPE_ASSIGN_CHECK(*in_shape, i, dshape); + } + out_shape->clear(); + out_shape->push_back(dshape); + return true; + } + + virtual OperatorProperty* Copy() const { + auto ptr = new ElementWiseSumProp(); + ptr->param_ = param_; + return ptr; + } + + virtual std::string TypeString() const { + return "ElementWiseSum"; + 
} + + virtual std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const { + return out_grad; + } + + virtual std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const { + return {{out_grad[0], in_grad[0]}}; + } + + virtual std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const { + return {{in_data[0], out_data[0]}}; + } + + Operator* CreateOperator(Context ctx) const; + + private: + ElementWiseSumParam param_; +}; // class ElementWiseSumProp + +#endif // DMLC_USE_CXX11 + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_ELEMENTWISE_SUM_INL_H_ diff --git a/src/operator/elementwise_sum.cc b/src/operator/elementwise_sum.cc new file mode 100644 index 000000000000..38e29141c7b3 --- /dev/null +++ b/src/operator/elementwise_sum.cc @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file elementwise_sum.cc + * \brief elementwise sum operator +*/ +#include +#include "./elementwise_sum-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(ElementWiseSumParam param) { + return new ElementWiseSumOp(param); +} + +// DO_BIND_DISPATCH comes from static_operator_common.h +Operator* ElementWiseSumProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(ElementWiseSumParam); + +REGISTER_OP_PROPERTY(ElementWiseSum, ElementWiseSumProp); +} // namespace op +} // namespace mxnet diff --git a/src/operator/elementwise_sum.cu b/src/operator/elementwise_sum.cu new file mode 100644 index 000000000000..7a9b443dad82 --- /dev/null +++ b/src/operator/elementwise_sum.cu @@ -0,0 +1,14 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file elementwise_sum.cu + * \brief elementwise sum operator +*/ +#include "./elementwise_sum-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(ElementWiseSumParam param) { + return new ElementWiseSumOp(param); +} +} // namespace op +} // namespace mxnet diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h index 5c54d37220ee..e92c9f1f66dd 100644 --- a/src/operator/fully_connected-inl.h +++ b/src/operator/fully_connected-inl.h @@ -7,7 +7,9 @@ #define MXNET_OPERATOR_FULLY_CONNECTED_INL_H_ #include +#include #include +#include #include #include #include @@ -22,6 +24,17 @@ namespace op { enum FullyConnectedOpInputs {kData, kWeight, kBias}; enum FullyConnectedOpOutputs {kOut}; +struct FullyConnectedParam : public dmlc::Parameter { + int num_hidden; + bool no_bias; + DMLC_DECLARE_PARAMETER(FullyConnectedParam) { + // TODO(bing) change to only set lower bound + // add support for boolean + DMLC_DECLARE_FIELD(num_hidden).set_range(1, 100000); + DMLC_DECLARE_FIELD(no_bias).set_default(false); + } +}; + /** * \brief This is the implementation of fully connected operator. * \tparam xpu The device that the op will be executed on. @@ -29,7 +42,7 @@ enum FullyConnectedOpOutputs {kOut}; template class FullyConnectedOp : public Operator { public: - explicit FullyConnectedOp(Param p) { + explicit FullyConnectedOp(FullyConnectedParam p) { this->param_ = p; } @@ -40,17 +53,17 @@ class FullyConnectedOp : public Operator { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(req[kOut], kWriteTo); - size_t expected = param_.no_bias == 0 ? 3 : 2; + size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context - Stream *s = static_cast *>(ctx.stream); + Stream *s = ctx.get_stream(); Tensor data = in_data[kData].FlatTo2D(s); Tensor wmat = in_data[kWeight].get(s); Tensor out = out_data[kOut].FlatTo2D(s); out = dot(data, wmat.T()); - if (param_.no_bias == 0) { + if (!param_.no_bias) { Tensor bias = in_data[kBias].get(s); out += repmat(bias, data.size(0)); } @@ -65,12 +78,12 @@ class FullyConnectedOp : public Operator { using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(out_grad.size(), 1); - size_t expected = param_.no_bias == 0 ? 3 : 2; + size_t expected = param_.no_bias ? 2 : 3; CHECK(in_data.size() == expected && in_grad.size() == expected); CHECK_EQ(req.size(), expected); // TODO(bing): check the BLAS Handle, be careful // maybe need blas handle from context - Stream *s = static_cast *>(ctx.stream); + Stream *s = ctx.get_stream(); Tensor data = in_data[kData].FlatTo2D(s); Tensor wmat = in_data[kWeight].get(s); Tensor grad = out_grad[kOut].FlatTo2D(s); @@ -80,7 +93,7 @@ class FullyConnectedOp : public Operator { Tensor gwmat = in_grad[kWeight].get(s); Assign(gwmat, req[kWeight], dot(grad.T(), data)); // gradient of bias - if (param_.no_bias == 0) { + if (!param_.no_bias) { Tensor gbias = in_grad[kBias].get(s); Assign(gbias, req[kBias], sum_rows(grad)); } @@ -90,49 +103,59 @@ class FullyConnectedOp : public Operator { } private: - /** The param of the fully connected layer.*/ - Param param_; + FullyConnectedParam param_; }; // class FullyConnectedOp // Decalre Factory function, used for dispatch specialization template -Operator* CreateFullyConnectedOp(Param param); +Operator* CreateOp(FullyConnectedParam param); #if DMLC_USE_CXX11 class FullyConnectedProp : public OperatorProperty { public: virtual std::vector ListArguments() const { - if (param_.no_bias == 0) { + if (!param_.no_bias) { return {"data", "weight", "bias"}; } else { return {"data", "weight"}; } } - virtual void SetParam(const char *name, const char *val) { - param_.SetParam(name, val); + virtual void Init(const std::vector >& kwargs) { + // TODO(bing) change directly to vector of pairs begin end + std::map kmap(kwargs.begin(), kwargs.end()); + param_.Init(kmap); } virtual bool InferShape(std::vector *in_shape, std::vector *out_shape) const { using namespace mshadow; - if (param_.no_bias == 0) { + if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]"; } CHECK_GT(param_.num_hidden, 0); const TShape &dshape = (*in_shape)[0]; - CHECK_EQ(dshape.ndim(), 4) << \ - "Input data should be 4D in batch-1-1-hidden"; - CHECK_NE(dshape.ndim(), 0) << "Require data shape to be known"; - ShapeAssignCheck((*in_shape)[kWeight], Shape2(param_.num_hidden, dshape[3])); - if (param_.no_bias == 0) { - ShapeAssignCheck((*in_shape)[kBias], Shape1(param_.num_hidden)); + // require data to be known + if (dshape.ndim() == 0) return false; + + index_t num_input; + if (dshape.ndim() == 4) { + // TODO(bing) consider deprecate 4D input + CHECK(dshape[1] == 1 && dshape[2] == 1); + num_input = dshape[3]; + } else { + CHECK_EQ(dshape.ndim(), 2) + << "FullyConnecteded: Input data should be 2D in (batch, num_hidden)"; + num_input = dshape[1]; + } + SHAPE_ASSIGN_CHECK(*in_shape, kWeight, Shape2(param_.num_hidden, num_input)); + if (!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, kBias, 
Shape1(param_.num_hidden)); } out_shape->clear(); - out_shape->push_back(dshape); - (*out_shape)[0][3] = param_.num_hidden; + out_shape->push_back(Shape2(dshape[0], param_.num_hidden)); return true; } @@ -164,7 +187,7 @@ class FullyConnectedProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; private: - Param param_; + FullyConnectedParam param_; }; // class FullyConnectedSymbol #endif } // namespace op diff --git a/src/operator/fully_connected.cc b/src/operator/fully_connected.cc index 362d3c5698aa..7d529cb3ed64 100644 --- a/src/operator/fully_connected.cc +++ b/src/operator/fully_connected.cc @@ -8,15 +8,17 @@ namespace mxnet { namespace op { template<> -Operator* CreateFullyConnectedOp(Param param) { +Operator* CreateOp(FullyConnectedParam param) { return new FullyConnectedOp(param); } // DO_BIND_DISPATCH comes from static_operator_common.h Operator* FullyConnectedProp::CreateOperator(Context ctx) const { - DO_BIND_DISPATCH(CreateFullyConnectedOp, param_); + DO_BIND_DISPATCH(CreateOp, param_); } +DMLC_REGISTER_PARAMETER(FullyConnectedParam); + REGISTER_OP_PROPERTY(FullyConnected, FullyConnectedProp); } // namespace op } // namespace mxnet diff --git a/src/operator/fully_connected.cu b/src/operator/fully_connected.cu index 223ef5166cc9..b97df8afb44c 100644 --- a/src/operator/fully_connected.cu +++ b/src/operator/fully_connected.cu @@ -7,7 +7,7 @@ namespace mxnet { namespace op { template<> -Operator* CreateFullyConnectedOp(Param param) { +Operator* CreateOp(FullyConnectedParam param) { return new FullyConnectedOp(param); } } // namespace op diff --git a/src/operator/static_operator/mshadow_op.h b/src/operator/mshadow_op.h similarity index 87% rename from src/operator/static_operator/mshadow_op.h rename to src/operator/mshadow_op.h index bb33471f168a..010cf0ce7cc9 100644 --- a/src/operator/static_operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1,17 +1,18 @@ /*! * Copyright (c) 2015 by Contributors * \file mshadow_op.h - * \brief extra mshadow operation for mxnet + * \brief * \author Bing Xu - */ -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ +*/ +#ifndef MXNET_OPERATOR_MSHADOW_OP_H_ +#define MXNET_OPERATOR_MSHADOW_OP_H_ + #include -#include namespace mxnet { -/*! \brief operations for ActivationLayer */ namespace op { +namespace mshadow_op { +/*! \brief identity Operation */ struct identity { MSHADOW_XINLINE static real_t Map(real_t a) { return a; @@ -98,9 +99,7 @@ struct square_root { return sqrt(a); } }; - +} // namespace mshadow_op } // namespace op } // namespace mxnet - -#endif // MXNET_OPERATOR_STATIC_OPERATOR_MSHADOW_OP_H_ - +#endif // MXNET_OPERATOR_MSHADOW_OP_H_ diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index 87b581f28278..eea731c8fbe6 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,6 +11,7 @@ #include #include #include +#include namespace mxnet { namespace op { @@ -34,20 +35,39 @@ inline void Assign(OType &out, // NOLINT(*) default: LOG(FATAL) << "not reached"; } } + +/*! \brief exception throwed by InferShape error */ +struct InferShapeError { + /*! \brief analyze message */ + std::string msg; + /*! \brief corresponding input index */ + int index; + // constructor + InferShapeError(std::string msg, int index) + : msg(msg), index(index) {} +}; + /*! 
- * \brief assign shape to out if out is unknown - * otherwise check consistency - * \param out the output shape to be stored + * \brief macro assign shape to out if out is unknown otherwise check consistency + * Use macro so we can see the error file more clearly + * \param shape_array the shape array to store the result + * \param index the index of in the array * \param shape the infered shape */ -template -inline void ShapeAssignCheck(TShape &out, const TS &shape) { // NOLINT(*) - if (out.ndim() == 0) { - out = shape; - } else { - CHECK(out == shape) << "InferShape:: shape inconsistent"; +#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ + { \ + auto &out = (shape_array)[index]; \ + if (out.ndim() == 0) { \ + out = shape; \ + } else { \ + if (out != shape) { \ + std::ostringstream os; \ + os << "Shape inconsistent, Provided " << '='<< out << ',' \ + << " inferred shape=" << shape; \ + throw ::mxnet::op::InferShapeError(os.str(), index); \ + } \ + } \ } -} // helper macro to implement bind dispatch #if MXNET_USE_CUDA diff --git a/src/operator/param.h b/src/operator/param.h index e1f6b4ee58d8..f0ce5886e2fb 100644 --- a/src/operator/param.h +++ b/src/operator/param.h @@ -35,10 +35,6 @@ struct Param { int no_bias; /*! \brief maximum temp_col_size allowed in each layer */ int temp_col_max; - /*! \brief number of input channels */ - int num_input_channel; - /*! \brief number of input hidden nodes, used by fullc */ - int num_input_node; /*! \brief reserved fields, for future compatibility */ int reserved[64]; @@ -48,11 +44,9 @@ struct Param { } inline void SetParam(const char *name, const char* val) { - if (!strcmp(name, "nhidden")) num_hidden = atoi(val); - if (!strcmp(name, "num_input_node")) num_input_node = atoi(val); - if (!strcmp(name, "num_input_channel")) num_input_channel = atoi(val); - if (!strcmp(name, "nchannel")) num_channel = atoi(val); - if (!strcmp(name, "ngroup")) num_group = atoi(val); + if (!strcmp(name, "num_hidden")) num_hidden = atoi(val); + if (!strcmp(name, "num_channel")) num_channel = atoi(val); + if (!strcmp(name, "num_group")) num_group = atoi(val); if (!strcmp(name, "kernel_size")) { kernel_y = kernel_x = atoi(val); } diff --git a/src/operator/static_operator/activation_op-inl.h b/src/operator/static_operator/activation_op-inl.h deleted file mode 100644 index cfb0b7cec8b5..000000000000 --- a/src/operator/static_operator/activation_op-inl.h +++ /dev/null @@ -1,61 +0,0 @@ -/*! 
- * Copyright (c) 2015 by Contributors - * \file activation_op-inl.h - * \brief activation operator of mxnet - */ - -#ifndef MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ -#define MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ - -#include -#include -#include -#include "./static_operator_common.h" - -namespace mxnet { -namespace op { -template -class ActivationOp : public StaticOperator { - public: - virtual void InferShape(std::vector *in_shape, - std::vector *out_shape) { - CHECK_EQ(in_shape->size(), 1) << "Only 1 input is allowed"; - CHECK_NE((*in_shape)[0].ndim(), 0) << "Require data shape to be known"; - out_shape->clear(); - out_shape->push_back((*in_shape)[0]); - } - virtual void Forward(Option opt, - RunContext ctx, - const std::vector &in_data, - const std::vector &out_data) { - CHECK_EQ(out_data.size(), 1); - CHECK_EQ(in_data.size(), 1); - mshadow::Stream *stream = \ - static_cast *>(ctx.stream); - mshadow::Tensor in = in_data[0].FlatTo2D(stream); - mshadow::Tensor out = out_data[0].FlatTo2D(stream); - out = mshadow::expr::F(in); - } - virtual void Backward(RunContext ctx, - const std::vector &grad_next, - const std::vector &in_data, - const std::vector &out_data, - const std::vector &out_grad, - const std::vector &req) { - CHECK_EQ(grad_next.size(), 1); - CHECK_EQ(in_data.size(), 1); - CHECK_EQ(out_grad.size(), 1); - CHECK_EQ(req.size(), 1); - mshadow::Stream *stream = \ - static_cast *>(ctx.stream); - mshadow::Tensor grad = grad_next[0].FlatTo2D(stream); - mshadow::Tensor data = in_data[0].FlatTo2D(stream); - mshadow::Tensor out = out_grad[0].FlatTo2D(stream); - Assign(out, req[0], mshadow::expr::F( - mshadow::expr::F(data)) * grad); - } -}; // class ActivationOp -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_STATIC_OPERATOR_ACTIVATION_OP_INL_H_ diff --git a/src/registry.cc b/src/registry.cc index 42fef1df3423..f64980d8bacc 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -25,12 +25,18 @@ Registry *Registry::Get() { return &instance; } -#if DMLC_USE_CXX11 + template NArrayFunctionEntry &Registry::Register(const std::string& name); template Registry *Registry::Get(); -#endif template OperatorPropertyEntry &Registry::Register(const std::string& name); template Registry *Registry::Get(); +// implementation of all factory functions +OperatorProperty *OperatorProperty::Create(const char* type_name) { + auto *creator = Registry::Find(type_name); + CHECK_NE(creator, nullptr) + << "Cannot find Operator " << type_name << " in registry"; + return (*creator)(); +} } // namespace mxnet diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index 5419e26afe86..3bec3427fbb3 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -7,14 +7,19 @@ #include #include #include +#include +#include "../operator/operator_common.h" namespace mxnet { std::vector StaticGraph::TopoSort() const { std::vector out_degree(nodes.size(), 0); - for (const Node &n : nodes) { - for (const DataEntry &e : n.inputs) { + for (const Node& n : nodes) { + for (const DataEntry& e : n.inputs) { ++out_degree[e.source_id]; } + if (n.is_backward()) { + ++out_degree[n.backward_source_id]; + } } std::vector ret(nodes.size()); auto result = ret.rbegin(); @@ -29,12 +34,17 @@ std::vector StaticGraph::TopoSort() const { queue.pop(); *result = node_id; ++result; - for (const DataEntry &e : nodes[node_id].inputs) { - out_degree[e.source_id] -= 1; - if (out_degree[e.source_id] == 0) { + const Node& n = nodes[node_id]; + for (const DataEntry& e : n.inputs) { + if 
(--out_degree[e.source_id] == 0) { queue.push(e.source_id); } } + if (n.is_backward()) { + if (--out_degree[n.backward_source_id] == 0) { + queue.push(n.backward_source_id); + } + } } return std::move(ret); } @@ -42,19 +52,73 @@ std::vector StaticGraph::TopoSort() const { bool StaticGraph::InferNodeShapes(const std::vector &topo_order, std::vector > *node_out_shapes) const { for (uint32_t nid : topo_order) { - const Node &node = nodes[nid]; - if (node.op != nullptr) { + const Node& node = nodes[nid]; + if (node.is_forward()) { std::vector in_shape; - for (const DataEntry &e : node.inputs) { + for (const DataEntry& e : node.inputs) { in_shape.push_back((*node_out_shapes)[e.source_id][e.index]); } - if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; + try { + if (!node.op->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; + } catch (const op::InferShapeError &err) { + // error handling + const std::string &op_name = node.name; + std::string arg_name = node.op->ListArguments()[err.index]; + std::ostringstream os; + os << "InferShape Error in " + << op_name << "\'s" << ' ' << arg_name << " argument\n"; + auto &source = nodes[node.inputs[err.index].source_id]; + if (source.is_variable()) { + os << "Corresponding keyword of symbol: " << source.name << '\n' << err.msg; + } + throw dmlc::Error(os.str()); + } for (size_t i = 0; i < node.inputs.size(); ++i) { - const DataEntry &e = node.inputs[i]; + const DataEntry& e = node.inputs[i]; (*node_out_shapes)[e.source_id][e.index] = in_shape[i]; } + } else if (nodes[nid].is_backward()) { + // simply use shapes from forward pass to assign backward shape + const Node& forward = nodes[node.backward_source_id]; + CHECK(forward.is_forward()); + std::vector& in_grad_shapes = (*node_out_shapes)[nid]; + CHECK(in_grad_shapes.size() == forward.inputs.size()); + // assign the input shape to output gradients + for (size_t i = 0; i < forward.inputs.size(); ++i) { + const DataEntry &e = forward.inputs[i]; + try { + SHAPE_ASSIGN_CHECK(in_grad_shapes, i, (*node_out_shapes)[e.source_id][e.index]); + } catch (const op::InferShapeError &err) { + const std::string &op_name = forward.name; + std::string arg_name = forward.op->ListArguments()[e.index]; + std::ostringstream os; + os << "InferShape Error in " + << op_name << "\'s" << ' ' << arg_name << " gradient argument\n" + << err.msg; + throw dmlc::Error(os.str()); + } + } + // consistent check for input shapes + auto& out_data_shapes = (*node_out_shapes)[node.backward_source_id]; + // use BackwardInputs to select entries corresponding to node.inputs + auto in_shape = forward.op->BackwardInputs( + out_data_shapes, in_grad_shapes, out_data_shapes); + for (size_t i = 0; i < node.inputs.size(); ++i) { + const DataEntry& e = node.inputs[i]; + try { + SHAPE_ASSIGN_CHECK((*node_out_shapes)[e.source_id], e.index, in_shape[i]); + } catch (const op::InferShapeError &err) { + const std::string &op_name = nodes[e.source_id].name; + std::ostringstream os; + os << "InferShape Error in " + << op_name << "\'s" << " gradient values\n" + << err.msg; + throw dmlc::Error(os.str()); + } + } } } + // TODO(bing) assign shape for head gradient return true; } @@ -63,8 +127,10 @@ bool StaticGraph::InferShape(std::vector *in_shape, std::vector > node_out_shapes(nodes.size()); for (size_t i = 0; i < nodes.size(); ++i) { int nout = 1; - if (nodes[i].op != nullptr) { + if (nodes[i].is_forward()) { nout = nodes[i].op->NumReturns(); + } else if (nodes[i].is_backward()) { + nout = 
+      nout = static_cast<int>(nodes[nodes[i].backward_source_id].inputs.size());
     }
     node_out_shapes[i].resize(nout);
   }
@@ -78,10 +144,110 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
   for (size_t i = 0; i < arg_nodes.size(); ++i) {
     (*in_shape)[i] = node_out_shapes[arg_nodes[i]][0];
   }
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    DataEntry e = outputs[i];
+  out_shape->resize(heads.size());
+  for (size_t i = 0; i < heads.size(); ++i) {
+    const DataEntry &e = heads[i];
     (*out_shape)[i] = node_out_shapes[e.source_id][e.index];
   }
   return true;
 }
+
+void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
+                                   std::vector<std::vector<DataEntry> > *arg_grads) {
+  arg_grads->clear();
+  head_grad_nodes->clear();
+  // get topo order of nodes, before new nodes are added
+  std::vector<uint32_t> topo_order = TopoSort();
+  // map out_data entry to out_grad
+  std::map<DataEntry, std::vector<DataEntry> > grad_map;
+  // allocate head gradient nodes
+  for (DataEntry head : heads) {
+    uint32_t nid = static_cast<uint32_t>(nodes.size());
+    // create a variable node for gradient input
+    nodes.push_back(Node());
+    Node &node = nodes[nid];
+    std::ostringstream os;
+    os << nodes[head.source_id].name << '_' << head.index << "_grad";
+    // TODO(bing): add index to name
+    node.name = os.str();
+    DataEntry igrad(nid, 0);
+    head_grad_nodes->push_back(nid);
+    // update gradient map
+    auto it = grad_map.find(head);
+    if (it == grad_map.end()) {
+      grad_map[head] = {igrad};
+    } else {
+      it->second.push_back(igrad);
+    }
+  }
+  // do backward pass traverse
+  for (auto it = topo_order.rbegin(); it != topo_order.rend(); ++it) {
+    uint32_t nid = *it;
+    // skip variables
+    if (nodes[nid].is_variable()) continue;
+    CHECK(nodes[nid].is_forward()) << "Do not support Backward of Backward";
+    // get out_grad and out_data entry
+    std::vector<DataEntry> out_grad, out_data;
+    // nvisible is out_grad.size()
+    int nvisible = nodes[nid].op->NumVisibleReturns();
+    // ntotal is out_data.size()
+    int ntotal = nodes[nid].op->NumReturns();
+    // check all outputs
+    for (int i = 0; i < ntotal; ++i) {
+      DataEntry odata(nid, static_cast<uint32_t>(i));
+      out_data.push_back(odata);
+      if (i >= nvisible) continue;
+      // get out_grad
+      auto it = grad_map.find(odata);
+      CHECK(it != grad_map.end()) << "bad graph";
+      std::vector<DataEntry> &gnodes = it->second;
+      if (gnodes.size() == 1) {
+        out_grad.push_back(gnodes[0]);
+      } else {
+        // found multiple gradients, need to aggregate them
+        std::ostringstream os_size, os_name;
+        uint32_t agg_node_id = static_cast<uint32_t>(nodes.size());
+        nodes.push_back(Node());
+        Node &agg_node = nodes[agg_node_id];
+        agg_node.op.reset(OperatorProperty::Create("ElementWiseSum"));
+        os_size << gnodes.size();
+        agg_node.op->Init({{"size", os_size.str()}});
+        os_name << nodes[nid].name << '_' << i << "_out_grad_agg";
+        agg_node.name = os_name.str();
+        agg_node.inputs = gnodes;
+        out_grad.push_back(DataEntry(agg_node_id, 0));
+      }
+    }
+    // Create a gradient backward node
+    uint32_t grad_node_id = static_cast<uint32_t>(nodes.size());
+    nodes.push_back(Node());
+    Node &grad_node = nodes[grad_node_id];
+    // Point to the corresponding source
+    grad_node.backward_source_id = nid;
+    // select out the dependent inputs
+    grad_node.inputs = nodes[nid].op->BackwardInputs(
+        out_grad, nodes[nid].inputs, out_data);
+    grad_node.name = nodes[nid].name + "_backward";
+
+    // update gradient map
+    for (size_t i = 0; i < nodes[nid].inputs.size(); ++i) {
+      DataEntry idata = nodes[nid].inputs[i];
+      DataEntry igrad(grad_node_id, static_cast<uint32_t>(i));
+      auto it = grad_map.find(idata);
+      if (it == grad_map.end()) {
+        grad_map[idata] = {igrad};
+      } else {
+        it->second.push_back(igrad);
+      }
+    }
+  }
+  // create return values of arg_grads
+  arg_grads->resize(arg_nodes.size());
+  for (size_t i = 0; i < arg_nodes.size(); ++i) {
+    DataEntry odata(arg_nodes[i], 0);
+    auto it = grad_map.find(odata);
+    CHECK(it != grad_map.end()) << "bad graph";
+    arg_grads->at(i) = it->second;
+  }
+}
 }  // namespace mxnet
diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index 86cf54feabfa..54a5fe9422b2 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -1,7 +1,7 @@
 /*!
  * Copyright (c) 2015 by Contributors
- * \file symbol.cc
- * \brief symbol of mxnet
+ *\file symbol.cc
+ *\brief symbol of mxnet
  */
 #include 
 #include 
@@ -12,13 +12,13 @@
 namespace mxnet {
 /*!
- * \brief Node is represents node of an operator in the symbolic graph.
+ *\brief Node is represents node of an operator in the symbolic graph.
  *
- * It stores connection to the inputs to function represented by OperatorProperty
- * NOTE on data structure: there are three types of node:
- * - Normal node: contains all the necessary elements of a graph.
- * - OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied.
- * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed.
+ *It stores connection to the inputs to function represented by OperatorProperty
+ *NOTE on data structure: there are three types of node:
+ *- Normal node: contains all the necessary elements of a graph.
+ *- OperatorProperty: the inputs_ is empty, represents an OperatorProperty that has not been applied.
+ *- Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed.
  */
 struct Symbol::Node {
   /*! \brief Operator of this node */
@@ -28,11 +28,11 @@ struct Symbol::Node {
   /*! \brief inputs to this node */
   std::vector<DataEntry> inputs;
   /*!
-   * \brief constructor
-   * \param op the OperatorProperty to construct the Node
-   * \param name the name of the symbol
+   *\brief constructor
+   *\param op the OperatorProperty to construct the Node
+   *\param name the name of the symbol
    */
-  explicit Node(OperatorProperty* op = nullptr, const std::string& name = "")
+  explicit Node(OperatorProperty *op = nullptr, const std::string& name = "")
       : op(op), name(name) {
   }
   /*! \return Whether the symbol is atomic */
@@ -63,7 +63,7 @@ inline void Symbol::DFSVisit(FVisit fvisit) const {
     }
   }
   while (!stack.empty()) {
-    Node* back = stack.back();
+    Node *back = stack.back();
     stack.pop_back();
     fvisit(back);
     for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) {
@@ -76,6 +76,28 @@ inline void Symbol::DFSVisit(FVisit fvisit) const {
   }
 }
 
+// helper function to handle keyword argument mismatch
+// throw appropriate messages
+template<typename TMap>
+inline void KeywordArgumentMismatch(const char *source,
+                                    const TMap &kwargs,
+                                    const std::vector<std::string> args) {
+  std::unordered_set<std::string> keys(args.begin(), args.end());
+  std::ostringstream head, msg;
+  msg << "\nCandidate arguments:\n";
+  for (size_t i = 0; i < args.size(); ++i) {
+    msg << "\t[" << i << ']' << args[i] << '\n';
+  }
+
+  for (const auto& kv : kwargs) {
+    if (keys.count(kv.first) == 0) {
+      LOG(FATAL) << source
+                 << "Keyword argument name " << kv.first << " not found."
+                 << msg.str();
+    }
+  }
+}
+
 int Symbol::FindDuplicateArgs(std::unordered_map<std::string, int> *out) const {
   out->clear();
   int max_dup = 1;
@@ -328,19 +350,8 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
     }
   }
   if (nmatched != kwargs.size()) {
-    // Error message handling
-    std::vector<std::string> req_args = this->ListArguments();
-    std::unordered_set<std::string> keys(req_args.begin(), req_args.end());
-    std::ostringstream msg;
-    msg << "\nCandidate arguments:\n";
-    for (size_t i = 0; i < req_args.size(); ++i) {
-      msg << "\t[" << i << ']' << req_args[i] << '\n';
-    }
-    for (const auto& kv : kwargs) {
-      CHECK_NE(keys.count(kv.first), 0)
-          << "Keyword Argument " << kv.first << " not found in arguments."
-          << msg.str();
-    }
+    KeywordArgumentMismatch(
+        "Symbol.Compose", kwargs, ListArguments());
   }
 }
@@ -358,11 +369,34 @@ Symbol Symbol::operator () (const std::unordered_map<std::string, Symbol>& kwargs) const {
   return s;
 }
-bool Symbol::InferShape(std::vector<TShape> *in_shape,
-                        std::vector<TShape> *out_shape) const {
+bool Symbol::InferShape(std::vector<TShape> *arg_shapes,
+                        std::vector<TShape> *out_shapes) const {
+  StaticGraph g;
+  this->ToStaticGraph(&g);
+  return g.InferShape(arg_shapes, out_shapes);
+}
+
+bool Symbol::InferShape(const std::unordered_map<std::string, TShape>& known_arg_shapes,
+                        std::vector<TShape> *arg_shapes,
+                        std::vector<TShape> *out_shapes) const {
   StaticGraph g;
   this->ToStaticGraph(&g);
-  return g.InferShape(in_shape, out_shape);
+  arg_shapes->clear();
+  arg_shapes->resize(g.arg_nodes.size(), TShape());
+  size_t nmatched = 0;
+  for (size_t i = 0; i < g.arg_nodes.size(); ++i) {
+    const std::string& name = g.nodes[g.arg_nodes[i]].name;
+    auto it = known_arg_shapes.find(name);
+    if (it != known_arg_shapes.end()) {
+      arg_shapes->at(i) = it->second;
+      ++nmatched;
+    }
+  }
+  if (nmatched != known_arg_shapes.size()) {
+    KeywordArgumentMismatch(
+        "Symbol.InferShape", known_arg_shapes, ListArguments());
+  }
+  return g.InferShape(arg_shapes, out_shapes);
 }
 Symbol Symbol::Create(OperatorProperty *op) {
@@ -424,12 +458,12 @@ void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
     }
   }
   // setup heads
-  out_graph->outputs.clear();
+  out_graph->heads.clear();
   for (auto &head : heads_) {
     StaticGraph::DataEntry e;
     e.source_id = node_index[head.source.get()];
     e.index = head.index;
-    out_graph->outputs.push_back(e);
+    out_graph->heads.push_back(e);
   }
 }
 }  // namespace mxnet
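
Note on the gradient bookkeeping added in StaticGraph::MakeBackwardPass above: every forward output is mapped to the list of gradient entries contributed by its consumers, and when more than one gradient arrives for the same output the patch inserts an ElementWiseSum node before wiring the backward node. The standalone C++ sketch below illustrates only that map-and-aggregate decision under simplified assumptions; the Entry struct, the plain integer node ids, and the printed messages are hypothetical stand-ins for this illustration and are not part of the mxnet code in the patch.

// Simplified illustration of the grad_map bookkeeping in MakeBackwardPass:
// one gradient entry is recorded per consumer, and multiple contributions
// to the same output are folded into one aggregation step (the real patch
// creates an "ElementWiseSum" operator node for this).
#include <cstdio>
#include <map>
#include <vector>

// Hypothetical stand-in for a graph data entry: (node id, output index).
struct Entry {
  int source_id;
  int index;
  bool operator<(const Entry& other) const {
    return source_id != other.source_id ? source_id < other.source_id
                                        : index < other.index;
  }
};

int main() {
  // Pretend node 0 has one output consumed by nodes 1 and 2, so two
  // gradient entries arrive for it during the backward pass.
  std::map<Entry, std::vector<Entry> > grad_map;
  grad_map[{0, 0}].push_back({1, 0});  // gradient from consumer node 1
  grad_map[{0, 0}].push_back({2, 0});  // gradient from consumer node 2

  for (const auto& kv : grad_map) {
    const std::vector<Entry>& gnodes = kv.second;
    if (gnodes.size() == 1) {
      // single consumer: its gradient is used directly
      std::printf("output (%d,%d): use gradient from node %d\n",
                  kv.first.source_id, kv.first.index, gnodes[0].source_id);
    } else {
      // multiple consumers: an ElementWiseSum-style node would be created
      // here to add the contributions before the backward op consumes them
      std::printf("output (%d,%d): aggregate %zu gradients\n",
                  kv.first.source_id, kv.first.index, gnodes.size());
    }
  }
  return 0;
}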