diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h index 54f8223e80a3..bdeac84048a6 100644 --- a/include/mxnet/atomic_symbol.h +++ b/include/mxnet/atomic_symbol.h @@ -28,12 +28,12 @@ class AtomicSymbol { */ virtual ~AtomicSymbol() {} /*! \brief get the descriptions of inputs for this symbol */ - virtual std::vector DescribeArguments() const { + virtual std::vector ListArguments() const { // default implementation returns "data" return std::vector(1, std::string("data")); } /*! \brief get the descriptions of outputs for this symbol */ - virtual std::vector DescribeReturns() const { + virtual std::vector ListReturns() const { // default implementation returns "output" return std::vector(1, std::string("output")); } @@ -77,6 +77,13 @@ class AtomicSymbol { */ virtual std::string TypeString() const = 0; friend class Symbol; + + /*! + * \brief create atomic symbol by type name + * \param type_name the type string of the AtomicSymbol + * \return a new constructed AtomicSymbol + */ + static AtomicSymbol *Create(const char* type_name); }; } // namespace mxnet diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 4f0a9ea5f87a..9ad75b4e5954 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -209,13 +209,21 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, // Part 3: symbolic configuration generation //-------------------------------------------- /*! - * \brief create symbol from config - * \param cfg configuration string - * \param out created symbol handle + * \brief list all the available AtomicSymbolEntry + * \param out_size the size of returned array + * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, - SymbolHandle *out); +MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array); +/*! + * \brief Get the name of AtomicSymbol. 
+ * \param creator the AtomicSymbolCreator + * \param out the returned name of the creator + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out); /*! * \brief create Symbol by wrapping AtomicSymbol * \param creator the AtomicSymbolCreator @@ -231,50 +239,79 @@ MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, const char **vals, SymbolHandle *out); /*! - * \brief free the symbol handle + * \brief Create a Variable Symbol. + * \param name name of the variable + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out); +/*! + * \brief Create symbol from config. + * \param cfg configuration string + * \param out created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, + SymbolHandle *out); +/*! + * \brief Free the symbol handle. * \param symbol the symbol * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolFree(SymbolHandle symbol); /*! - * \brief list all the available AtomicSymbolEntry - * \param out_size the size of returned array - * \param out_array the output AtomicSymbolCreator array + * \brief Copy the symbol to another handle + * \param symbol the source symbol + * \param out used to hold the result of copy * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, - AtomicSymbolCreator **out_array); +MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out); /*! - * \brief get the singleton Symbol of the AtomicSymbol if any - * \param creator the AtomicSymbolCreator - * \param out the returned singleton Symbol of the AtomicSymbol the creator stands for + * \brief Print the content of symbol, used for debug. 
+ * \param symbol the symbol + * \param out_str pointer to hold the output string of the printing. * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator, - SymbolHandle *out); +MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str); /*! - * \brief get the singleton Symbol of the AtomicSymbol if any - * \param creator the AtomicSymbolCreator - * \param out the returned name of the creator + * \brief List arguments in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, - const char **out); +MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array); /*! - * \brief compose the symbol on other symbol + * \brief List returns in the symbol. + * \param symbol the symbol + * \param out_size output size + * \param out_str_array pointer to hold the output string array + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array); +/*! + * \brief Compose the symbol on other symbols. + * + * This function will change the sym handle. + * To achieve function apply behavior, copy the symbol first + * before apply. 
+ * * \param sym the symbol to apply + * \param name the name of symbol * \param num_args number of arguments * \param keys the key of keyword args (optional) * \param args arguments to sym - * \param out the resulting symbol * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolCompose(SymbolHandle sym, + const char *name, mx_uint num_args, const char** keys, - SymbolHandle* args, - SymbolHandle* out); - + SymbolHandle* args); //-------------------------------------------- // Part 4: operator interface on NArray //-------------------------------------------- diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index df9c27b9a4ad..dcc87b6ee232 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -229,18 +229,9 @@ struct AtomicSymbolEntry { std::string name; /*! \brief function body to create AtomicSymbol */ Creator body; - /*! \brief singleton is created when no param is needed for the AtomicSymbol */ - Symbol *singleton_symbol; /*! \brief constructor */ explicit AtomicSymbolEntry(const std::string& name) - : use_param(true), name(name), body(NULL), singleton_symbol(NULL) {} - /*! - * \brief set if param is needed by this AtomicSymbol - */ - inline AtomicSymbolEntry &set_use_param(bool use_param) { - this->use_param = use_param; - return *this; - } + : use_param(true), name(name), body(NULL) {} /*! * \brief set the function body */ @@ -248,12 +239,6 @@ struct AtomicSymbolEntry { this->body = body; return *this; } - /*! - * \brief return the singleton symbol - */ - Symbol *GetSingletonSymbol(); - /*! \brief destructor */ - ~AtomicSymbolEntry(); /*! * \brief invoke the function * \return the created AtomicSymbol diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index e06cbf4f9c42..4aadbe2fb631 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -21,39 +21,28 @@ struct StaticGraph { /*! \brief Node in static graph */ struct StaticNode { /*! 
\brief wrapped atomic symbol */ - AtomicSymbol* sym_; + std::unique_ptr sym; /*! \brief name of the node */ - std::string name_; + std::string name; + /*! \brief index of output from the source. */ + int index; + /*! \brief input shape for node */ + std::vector in_shape; + /*! \brief output shape for node */ + std::vector out_shape; + /*! \brief input id for each node */ + std::vector inputs_index; + /*! \brief output id for each node */ + std::vector outputs_index; }; + /*! \brief head node (need input from outside) */ + std::vector in_args_node_id; + /*! \brief tail node (generate data to outside) */ + std::vector return_node_id; /*! \brief node name to id dictionary */ std::unordered_map name_id_map; /*! \brief all nodes in the graph */ std::vector nodes; - /*! \brief output id for each node */ - std::vector > output_index; - /*! \brief connected graph for each node */ - std::vector > connected_graph; - /*! \brief find node by using name - * \param name node name - * \param sym symbol need to be copied into node - * \return node id - */ - int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { - int id = 0; - if (name_id_map.find(name) == name_id_map.end()) { - name_id_map[name] = name_id_map.size(); - StaticNode static_node; - static_node.sym_ = sym->Copy(); - static_node.name_ = name; - nodes.push_back(static_node); - output_index.push_back(std::vector()); - connected_graph.push_back(std::vector()); - id = name_id_map.size(); - } else { - id = name_id_map[name]; - } - return id; - } }; } // namespace mxnet #endif // MXNET_STATIC_GRAPH_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index df1e78438560..e7869e6b89d2 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -1,145 +1,272 @@ /*! 
* Copyright (c) 2015 by Contributors * \file symbol.h - * \brief symbol interface of mxnet + * \brief symbolic interface of mxnet */ #ifndef MXNET_SYMBOL_H_ #define MXNET_SYMBOL_H_ #include -#include +#include #include #include +#include #include +#include #include #include +#include #include "./base.h" #include "./tensor_blob.h" #include "./operator.h" #include "./static_graph.h" namespace mxnet { -class CompositeOperator; /*! - * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol - * should support expressions and often passed by value. While AtomicSymbol have many subclasses, - * passing by value would result in object slicing. + * \brief Symbol is used to represent dynamically generated symbolic computation graph. * - * Symbol is always composite, the head Node is the output node of the symbol. - * A atomic symbol can be seen as a special case of the composite symbol with only the head node. + * This class is used as a tool to generate computation graphs(aka. configuration) of the network. + * Symbol is always composite, the head Node is the output node of the symbol. + * An atomic symbol can be seen as a special case of the composite symbol with only the head node. + * + * The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet. + * Symbol offers more flexible way to composite nodes than StaticGraph, which makes it good + * tool to generate configurations from language bindings such as python. + * \sa StaticGraph */ class Symbol { public: /*! - * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol - * with input symbols. + * \brief copy the symbol + * \return a deep copy of the graph */ - struct Node { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - /*! \brief inputs to this node */ - std::vector > in_symbol_; - /*! 
\brief index of the inputs if the inputs are tuple */ - std::vector in_index_; - /*! \brief the output shape of the wrapped symbol */ - std::vector out_shape_; - /*! - * \brief constructor - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : - sym_(sym), name_(name) { - } - /*! - * \brief destructor - */ - ~Node() { - if (sym_) { - delete sym_; - } - } - }; - - protected: - /*! \brief the head node of the Symbol, it could be shared in many graphs */ - std::shared_ptr head_; - /*! \brief if the head has multiple return values, index is used to specify */ - int index_; - /*! \brief find the nodes that use placeholder arguments */ - std::shared_ptr > > arg_users_; - /*! \brief find arg users */ - void FindArgUsers(); - /** - * @brief Recursively parse the symbol to equivalent static graph. + Symbol Copy() const; + /*! + * \brief print the symbol info to output stream. + * \param os the output stream we like to print to + */ + void Print(std::ostream &os) const; // NOLINT(*) + /*! + * \brief List the arguments names. * - * @param node The current node in dfs - * @param graph The static graph + * The position of the returned list also corresponds to calling position in operator() + * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ - void Dfs(const std::shared_ptr node, StaticGraph *graph); - - public: + std::vector ListArguments() const; + /*! \return get the descriptions of outputs for this symbol */ + std::vector ListReturns() const; /*! - * \brief declare virtual destructor in case it is subclassed. + * \brief get the index th element from the returned tuple. + * \param index index of multi output + * \return the symbol corresponds to the indexed element. */ - virtual ~Symbol() {} + Symbol operator[] (int index) const; /*! - * \brief bind to device and returns an operator. - * \param ctx context of the operator - * \return returns the pointer to a created operator. 
It is on the user to delete. + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. */ - virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } - /** - * \brief Bind the symbol to a composite operator - * \param ctx context of the operator - * \param in A map denotes name and corresponding NArray for binding - * \return The composite operator + bool InferShape(std::vector *in_shape, + std::vector *out_shape); + /*! + * \brief Compose the symbol with arguments, this changes current symbol. + * + * The positional arguments passed in must be complete(contain all arguments). + * + * \param args positional arguments for the symbol + * \param name name of returned symbol. */ - virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); + void Compose(const std::vector& args, + const std::string& name); /*! - * \brief copy the symbol - * \return a deep copy of the graph + * \brief Compose the symbol with arguments, this changes the current symbol. + * The kwargs passed in can be in-complete, + * + * The rest of the symbols will remain the same name. + * + * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. */ - virtual Symbol Copy() const; + void Compose(const std::unordered_map& kwargs, + const std::string& name); /*! 
- * \brief compose with arguments + * \brief Apply the symbol as a function, compose with arguments * \param args positional arguments for the symbol + * \param name name of returned symbol. * \return a new Symbol which is the composition of current symbol with its arguments */ - virtual Symbol operator () (const std::vector& args) const; + inline Symbol operator () (const std::vector& args, + const std::string& name) const { + Symbol s = this->Copy(); + s.Compose(args, name); + return s; + } /*! * \brief compose with named arguments * \param kwargs keyword arguments for the symbol + * \param name name of returned symbol. * \return a new symbol which is the composition of current symbol with its arguments */ - virtual Symbol operator () (const std::unordered_map& kwargs) const; + inline Symbol operator () (const std::unordered_map& kwargs, + const std::string& name) { + Symbol s = this->Copy(); + s.Compose(kwargs, name); + return s; + } /*! - * \brief get the index th element from the returned tuple. + * \brief create Symbol by wrapping AtomicSymbol + * This function takes the ownership of atomic_symbol. + * + * \param atomic_symbol the AtomicSymbol + * \return Symbol + * \sa AtomicSymbol::Create */ - virtual Symbol operator[] (int index) const; + static Symbol Create(AtomicSymbol *atomic_symbol); /*! - * \brief arguments information - * \return the arguments list of this symbol, they can be either named or unnamed (empty string). + * \brief create equivalence of symbols from static graphs + * \param graph the static graph + * \return list of Symbols representing outputs of the graph */ - virtual std::vector ListArgs(); - /** - * @brief Convert current symbol to its equivalent static graph representation. - * @return the static graph + static std::vector Create(const StaticGraph &graph); + /*! 
+ * \brief Convert a list of symbols into static graph + * + * The user can go further to call bind function on static graph + * + * \param heads the heads of the graph + * \param out_graph the pointer holder of the output graph */ - virtual StaticGraph ToStaticGraph(); + static void Convert(const std::vector &heads, StaticGraph *out_graph); /*! - * \brief create Symbol by wrapping AtomicSymbol + * \brief create variable symbol node + * \param name name of the variable + * \return the new variable */ - static Symbol Create(AtomicSymbol* atomic_symbol); + inline static Symbol CreateVariable(const std::string &name) { + Symbol s; + s.head_ = DataEntry(std::make_shared(nullptr, name), 0); + return std::move(s); + } + + protected: + // forward declare Node + struct Node; + /*! \brief an entry that represents output data from a node */ + struct DataEntry { + /*! \brief the source node of this data */ + std::shared_ptr source; + /*! + * \brief index of output from the source. + * If index == -1, it represents all the outputs. + */ + int index; + /*! \brief enabled default copy constructor */ + DataEntry() {} + /*! \brief constructor from index */ + DataEntry(std::shared_ptr source, int index) + : source(source), index(index) {} + }; /*! - * \brief create atomic symbol wrapped in symbol - * \param type_name the type string of the AtomicSymbol - * \param param the parameter stored as key value pairs - * \return the constructed Symbol + * \brief Node is represents node of an operator in the symbolic graph. + * + * It stores connection to the inputs to function represented by AtomicSymbol + * NOTE on data structure: there are three types of node: + * - Normal node: contains all the necessary elements of a graph. + * - AtomicSymbol: the inputs_ is empty, represents an AtomicSymbol that has not been applied. + * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed. 
*/ - static Symbol Create(const std::string& type_name, - const std::vector >& param); -}; + struct Node { + /*! \brief wrapped atomic symbol */ + std::unique_ptr sym; + /*! \brief name of the node */ + std::string name; + /*! \brief inputs to this node */ + std::vector inputs; + /*! + * \brief constructor + * \param sym the AtomicSymbol to construct the symbol + * \param name the name of the symbol + */ + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") + : sym(sym), name(name) { + } + /*! \return Whether the symbol is AtomicSymbol */ + inline bool is_atomic() const { + return inputs.size() == 0 && sym != nullptr; + } + /*! \return Whetehr the symbolc is a PlaceHolder */ + inline bool is_variable() const { + return sym == nullptr; + } + }; + /*! \brief the head node of the Symbol */ + DataEntry head_; + private: + /*! \brief DFS Visit for symbol with single head + * This function call is specail case for DFSVisit_ + * \param fvisit function applied for each visit. + * \tparam FVisit visiting function type + */ + template + inline void DFSVisit(FVisit fvisit) const { + std::vector tmp = { head_ }; + DFSVisit(tmp, fvisit); + } + /*! + * \brief Visit all the nodes in left-to-right depth first order. + * + * This function will visit the graph in DFS order, call fvisit exactly once + * for each Node, and store the result in out_result. + * + * \param fvisit function applied for each visit. 
+ * \tparam FVisit visiting function type + */ + template + static inline void DFSVisit(const std::vector &heads, + FVisit fvisit) { + std::vector stack; + std::unordered_set visited; + // put the head into the graph + for (auto &head : heads) { + Node *ptr = head.source.get(); + stack.push_back(ptr); + visited.insert(ptr); + } + while (!stack.empty()) { + Node* back = stack.back(); + stack.pop_back(); + fvisit(back); + for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) { + Node *ptr = it->source.get(); + if (visited.count(ptr) == 0) { + stack.push_back(ptr); + visited.insert(ptr); + } + } + } + } + /*! \brief Toposort the symbol + * \param heads symbol's head + * \param ret sorted nodes + */ + static inline void Toposort(const std::vector &heads, + std::vector *ret); + /*! + * \brief Find duplicate arguments in the composition + * \param out the map of argument-name -> occurrence count + * \return maximum number of duplication factor + */ + int FindDuplicateArgs(std::unordered_map *out) const; +}; } // namespace mxnet #endif // MXNET_SYMBOL_H_ diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 6f4146d162e3..031b18ab862f 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -1,4 +1,5 @@ # coding: utf-8 +# pylint: disable=invalid-name, protected-access """Symbol support of mxnet""" from __future__ import absolute_import @@ -37,37 +38,113 @@ def __init__(self, handle): """ self.handle = handle + def __del__(self): + check_call(_LIB.MXSymbolFree(self.handle)) + + def __copy__(self): + return self.__deepcopy__() + + def __deepcopy__(self): + handle = SymbolHandle() + check_call(_LIB.MXSymbolCopy(self.handle, + ctypes.byref(handle))) + return Symbol(handle) + def __call__(self, *args, **kwargs): - """Compose Symbols + """Invoke symbol as function on inputs. 
+ + Parameters + ---------- + args: + provide positional arguments + + kwargs: + provide keyword arguments + Returns + ------- + the resulting symbol + """ + s = self.__deepcopy__() + s._compose(*args, **kwargs) + return s + + def _compose(self, *args, **kwargs): + """Compose symbol on inputs. + + This call mutates the current symbol. Parameters ---------- args: provide positional arguments + kwargs: provide keyword arguments Returns ------- the resulting symbol """ - assert (len(args) == 0 or len(kwargs) == 0) + name = kwargs.pop('name', None) + if name: + name = c_str(name) + if len(args) != 0 and len(kwargs) != 0: + raise TypeError('compose only accept input Symbols \ + either as positional or keyword arguments, not both') + for arg in args: - assert isinstance(arg, Symbol) - for _, val in kwargs: - assert isinstance(val, Symbol) + if not isinstance(arg, Symbol): + raise TypeError('Compose expect `Symbol` as arguments') + for _, val in kwargs.items(): + if not isinstance(val, Symbol): + raise TypeError('Compose expect `Symbol` as arguments') + num_args = len(args) + len(kwargs) if len(kwargs) != 0: keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - args = c_array(SymbolHandle, kwargs.values()) + args = c_array(SymbolHandle, [s.handle for s in kwargs.values()]) else: keys = None - args = c_array(SymbolHandle, args) - - out = SymbolHandle() - check_call(_LIB.MXSymbolCompose( - self.handle, - num_args, - keys, - args, - ctypes.byref(out))) - return Symbol(out) + args = c_array(SymbolHandle, [s.handle for s in args]) + check_call(_LIB.MXSymbolCompose( \ + self.handle, name, num_args, keys, args)) + + def list_arguments(self): + """List all the arguments in the symbol. + + Returns + ------- + args : list of string + List of all the arguments. 
+ """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolListArguments( \ + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [sarr[i] for i in range(size.value)] + + def list_returns(self): + """List all returns in the symbol. + + Returns + ------- + args: list of string + List of all the returns. + """ + size = ctypes.c_uint() + sarr = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolListReturns( \ + self.handle, ctypes.byref(size), ctypes.byref(sarr))) + return [sarr[i] for i in range(size.value)] + + def debug_str(self): + """Get a debug string. + + Returns + ------- + debug_str : string + Debug string of the symbol. + """ + debug_str = ctypes.c_char_p() + check_call(_LIB.MXSymbolPrint( \ + self.handle, ctypes.byref(debug_str))) + return debug_str.value diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py index ee8f8bba1525..a0617ce395f8 100644 --- a/python/mxnet/symbol_creator.py +++ b/python/mxnet/symbol_creator.py @@ -1,11 +1,12 @@ # coding: utf-8 +# pylint: disable=invalid-name, protected-access, no-self-use """Symbol support of mxnet""" from __future__ import absolute_import import ctypes from .base import _LIB -from .base import c_array, c_str -from .base import mx_uint, SymbolHandle +from .base import c_array, c_str, string_types +from .base import SymbolHandle from .base import check_call from .symbol import Symbol @@ -25,34 +26,53 @@ def __init__(self, name, handle): """ self.name = name self.handle = handle - singleton_ = SymbolHandle() - check_call(_LIB.MXSymbolGetSingleton(self.handle, ctypes.byref(singleton_))) - if singleton_: - self.singleton = Symbol(singleton_) - else: - self.singleton = None - - def __call__(self, **kwargs): + + def __call__(self, *args, **kwargs): """Invoke creator of symbol by passing kwargs Parameters ---------- + name : string + Name of the resulting symbol. 
+ + *args + Positional arguments + **kwargs - provide the params necessary for the symbol creation + Provide the params necessary for the symbol creation. + Returns ------- the resulting symbol """ - keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) - vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) + param_keys = [] + param_vals = [] + symbol_kwargs = {} + name = kwargs.pop('name', None) + + for k, v in kwargs.items(): + if isinstance(v, Symbol): + symbol_kwargs[k] = v + else: + param_keys.append(k) + param_vals.append(c_str(str(v))) + + # create atomic symbol + param_keys = c_array(ctypes.c_char_p, param_keys) + param_vals = c_array(ctypes.c_char_p, param_vals) sym_handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateFromAtomicSymbol( - self.handle, - mx_uint(len(kwargs)), - keys, - vals, - ctypes.byref(sym_handle))) - return Symbol(sym_handle) + check_call(_LIB.MXSymbolCreateFromAtomicSymbol( \ + self.handle, len(param_keys), \ + param_keys, param_vals, \ + ctypes.byref(sym_handle))) + + if len(args) != 0 and len(symbol_kwargs) != 0: + raise TypeError('%s can only accept input \ + Symbols either as positional or keyword arguments, not both' % self.name) + + s = Symbol(sym_handle) + s._compose(*args, name=name, **symbol_kwargs) + return s class _SymbolCreatorRegistry(object): """Function Registry""" @@ -62,8 +82,27 @@ def __init__(self): check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size), ctypes.byref(plist))) hmap = {} - name = ctypes.c_char_p() for i in range(size.value): - name = _LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name)) - hmap[name] = _SymbolCreator(name, plist[i]) + name = ctypes.c_char_p() + check_call(_LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name))) + hmap[name.value] = _SymbolCreator(name, plist[i]) self.__dict__.update(hmap) + + def Variable(self, name): + """Create a symbolic variable with specified name. 
+ + Parameters + ---------- + name : str + Name of the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + if not isinstance(name, string_types): + raise TypeError('Expect a string for variable `name`') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle))) + return Symbol(handle) diff --git a/python/test_python.py b/python/test_python.py index 5fe4a03ac11a..7aa4c432f1db 100644 --- a/python/test_python.py +++ b/python/test_python.py @@ -1,4 +1,4 @@ -#pylint: skip-file +# pylint: skip-file import mxnet as mx a = mx.narray.create((3000, 4000)) diff --git a/python/test_symbol.py b/python/test_symbol.py new file mode 100644 index 000000000000..f4823a087e3c --- /dev/null +++ b/python/test_symbol.py @@ -0,0 +1,23 @@ +# pylint: skip-file +import mxnet as mx + +data = mx.sym.Variable('data') +print data.debug_str() + +fc1 = mx.sym.FullyConnected(data=data, name='fc1', no_bias=0) +fc2 = mx.sym.FullyConnected(data=fc1, name='fc2', no_bias=0) + +print fc2.debug_str() + +print fc2.list_arguments() + +fc3 = mx.sym.FullyConnected(name='fc3') +fc4 = mx.sym.FullyConnected(data=fc3, name='fc4') + +print fc4.debug_str() + +print "-" * 10 +composed_fc4 = fc4(fc3_data=fc2, name='composed') +print composed_fc4.debug_str() + + diff --git a/src/c_api.cc b/src/c_api.cc index 5452c1be2e3d..df8eb349752c 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -12,6 +12,7 @@ #include #include #include +#include // macro hanlding for threadlocal variables #ifdef __GNUC__ @@ -26,6 +27,18 @@ #message("Warning: Threadlocal is not enabled"); #endif +/*! \brief symbol wrapper to easily hold returning information */ +struct MXAPISymbolWrapper { + /*! \brief the actual symbol */ + mxnet::Symbol sym; + /*! \brief result holder for returning string */ + std::string ret_str; + /*! \brief result holder for returning strings */ + std::vector ret_vec_str; + /*! 
\brief result holder for returning string pointers */ + std::vector ret_vec_charp; +}; + /*! * \brief helper to store error message in threadlocal storage */ @@ -86,8 +99,15 @@ using namespace mxnet; /*! \brief macro to guard beginning and end section of all functions */ #define API_BEGIN() try { -/*! \brief every function starts with API_BEGIN(); and finishes with API_END(); */ +/*! \brief every function starts with API_BEGIN(); + and finishes with API_END() or API_END_HANDLE_ERROR */ #define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0; +/*! + * \brief every function starts with API_BEGIN(); + * and finishes with API_END() or API_END_HANDLE_ERROR + * The finally clause contains procedure to cleanup states when an error happens. + */ +#define API_END_HANDLE_ERROR(Finalize) } catch(dmlc::Error &e) { Finalize; return MXHandleException(e); } return 0; // NOLINT(*) /*! \brief return str message of the last error */ const char *MXGetLastError() { @@ -249,74 +269,135 @@ int MXFuncInvoke(FunctionHandle fun, API_END(); } +//-------------------------------------------- +// Part 3: symbolic configuration generation +//-------------------------------------------- + +int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array) { + API_BEGIN(); + auto &vec = Registry::List(); + *out_size = static_cast(vec.size()); + *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + API_END(); +} + +int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out) { + API_BEGIN(); + AtomicSymbolEntry *e = static_cast(creator); + *out = e->name.c_str(); + API_END(); +} + int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, int num_param, const char **keys, const char **vals, SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + AtomicSymbol *atomic_symbol = nullptr; + API_BEGIN(); AtomicSymbolEntry *e = static_cast(creator); - *out = static_cast(new Symbol); - 
AtomicSymbol *atomic_symbol = (*e)(); + atomic_symbol = (*e)(); for (int i = 0; i < num_param; ++i) { atomic_symbol->SetParam(keys[i], vals[i]); } - *static_cast(*out) = Symbol::Create(atomic_symbol); - API_END(); + s->sym = Symbol::Create(atomic_symbol); + *out = s; + API_END_HANDLE_ERROR(delete s; delete atomic_symbol); +} + +int MXSymbolCreateVariable(const char *name, SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + API_BEGIN(); + s->sym = Symbol::CreateVariable(name); + *out = s; + API_END_HANDLE_ERROR(delete s); } int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); - delete static_cast(symbol); + delete static_cast(symbol); API_END(); } -int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, - AtomicSymbolCreator **out_array) { +int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out) { + MXAPISymbolWrapper *s = new MXAPISymbolWrapper(); + API_BEGIN(); - auto &vec = Registry::List(); - *out_size = static_cast(vec.size()); - *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) + s->sym = (static_cast(symbol)->sym).Copy(); + *out = s; + API_END_HANDLE_ERROR(delete s); +} + +int MXSymbolPrint(SymbolHandle symbol, const char **out_str) { + MXAPISymbolWrapper *s = static_cast(symbol); + + API_BEGIN(); + std::ostringstream os; + (s->sym).Print(os); + s->ret_str = os.str(); + *out_str = (s->ret_str).c_str(); API_END(); } -int MXSymbolGetSingleton(AtomicSymbolCreator creator, - SymbolHandle *out) { +int MXSymbolListArguments(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array) { + MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - *out = static_cast(e->GetSingletonSymbol()); + if (s->ret_vec_charp.size() == 0) { + s->ret_vec_str = std::move((s->sym).ListArguments()); + for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + } + } + *out_size = static_cast(s->ret_vec_charp.size()); + 
*out_str_array = dmlc::BeginPtr(s->ret_vec_charp); API_END(); } -int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, - const char **out) { +int MXSymbolListReturns(SymbolHandle symbol, + mx_uint *out_size, + const char ***out_str_array) { + MXAPISymbolWrapper *s = static_cast(symbol); API_BEGIN(); - AtomicSymbolEntry *e = static_cast(creator); - *out = e->name.c_str(); + s->ret_vec_charp.clear(); + if (s->ret_vec_charp.size() == 0) { + s->ret_vec_str = std::move((s->sym).ListReturns()); + for (size_t i = 0; i < s->ret_vec_str.size(); ++i) { + s->ret_vec_charp.push_back(s->ret_vec_str[i].c_str()); + } + } + *out_size = static_cast(s->ret_vec_charp.size()); + *out_str_array = dmlc::BeginPtr(s->ret_vec_charp); API_END(); } int MXSymbolCompose(SymbolHandle sym, + const char *name, mx_uint num_args, const char** keys, - SymbolHandle* args, - SymbolHandle* out) { + SymbolHandle* args) { API_BEGIN(); - const Symbol* s = static_cast(sym); - Symbol* ret = new Symbol; - if (keys == NULL) { + std::string s_name; + if (name != nullptr) s_name = name; + + MXAPISymbolWrapper* s = static_cast(sym); + if (keys == nullptr && num_args != 0) { std::vector pos_args; for (mx_uint i = 0; i < num_args; ++i) { - pos_args.push_back(*(Symbol*)(args[i])); // NOLINT(*) + pos_args.push_back(((MXAPISymbolWrapper*)(args[i]))->sym); // NOLINT(*) } - *ret = (*s)(pos_args); + (s->sym).Compose(pos_args, s_name); } else { std::unordered_map kwargs; for (mx_uint i = 0; i < num_args; ++i) { - kwargs[keys[i]] = *(Symbol*)(args[i]); // NOLINT(*) + kwargs[keys[i]] = ((MXAPISymbolWrapper*)(args[i]))->sym; // NOLINT(*) } - *ret = (*s)(kwargs); + (s->sym).Compose(kwargs, s_name); } - *out = ret; API_END(); } diff --git a/src/operator/composite_operator.h b/src/operator/composite_operator.h new file mode 100644 index 000000000000..ddaa0a50f561 --- /dev/null +++ b/src/operator/composite_operator.h @@ -0,0 +1,102 @@ +/*! 
+ * Copyright (c) 2015 by Contributors + * \file composite_operator.h + * \brief composite operator of mxnet + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ +#define MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ +#include +#include +#include +#include "./atomic_symbol.h" +#include "./base.h" +#include "./static_graph.h" +#include "./static_operator.h" + +namespace mxnet { +/*! + * \brief composite_operator interface + * composite operator is a combination of static operator from static graph + */ +class CompositeOperator : public Operator { + public: + /*! \brief destructor */ + virtual ~CompositeOperator() {} + /*! + * \brief describe property of op + * \return a bit map in int + */ + virtual int DescribeProperty() const { + // default most of layer only conatin internal state + return kContainInteralState; + } + /*! \brief Make operator by using graph + * \param ctx ctx context of the created operator + * \param in input narray + * \param graph input static graph + */ + void Bind(Context ctx, + const std::vector &in, + std::shared_ptr graph); + /*! + * \brief perform a forward operation of operator, save the output to NArray + * This method only pushes an execution request to the DAG engine, and + * return immediately. Actual execution is conducted by the DAG engine. + * \param opt option on Forward such as whether this is training phase + * \param ctx runtime context + * \param in_data array of input data, it is const + * \param out_data array of output data, + * the space of NArray in out_data must be pre-allocated with InferShape + * \sa NArray + */ + virtual void Forward(Option opt, + RunContext ctx, + const std::vector &in_data, + const std::vector &out_data); + /*! + * \brief perform a backward operation of the operator to get the gradient + * This method only pushes an execution request to the DAG engine, and + * return immediately. Actual execution is conducted by the DAG engine. 
+ * \param ctx runtime context + * \param grad_next the gradient value of the output of the operator, used by chain rule. + * \param in_data the array of input data + * \param out_grad array of output gradient + * \param req request types of the gradient saving operation + * only inplace will change input data + * \sa GradReqType, NArray + */ + virtual void Backward(RunContext ctx, + const std::vector &grad_next, + const std::vector &in_data, + const std::vector &out_grad, + const std::vector &req); + /*! + * \brief perform an extraction operation to get feature map + * \param name of symbol need to be extracted + * \return empty narray for invalid name or narray of the feature map + */ + virtual NArray Extract(const std::string &symbol_name); + + private: + /*! \brief + struct Connection { + + }; + /*! \brief static operators for each node */ + std::vector > static_ops_; + /*! \brief feature map for each op */ + std::vector > feature_maps_; + /*! \brief input NArray link */ + std::vector > in_data_; + /*! \brief input NArray gradient */ + std::vector > in_grad_; + /*! \brief output NArray link */ + std::vector > out_data_; + /*! 
\brief static graph */ + std::shared_ptr graph_; +}; // class CompositeOperator +} // namespace mxnet +#endif // MXNET_OPERATOR_COMPOSITE_OPERATOR_H_ + + diff --git a/src/registry.cc b/src/registry.cc index d51907dbdcd7..f3ce0bd28ff0 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -30,24 +30,6 @@ template NArrayFunctionEntry &Registry::Register(const std: template Registry *Registry::Get(); #endif -Symbol *AtomicSymbolEntry::GetSingletonSymbol() { - if (singleton_symbol) { - return singleton_symbol; - } else if (body && !use_param) { - singleton_symbol = new Symbol; - *singleton_symbol = Symbol::Create(body()); - return singleton_symbol; - } else { - return NULL; - } -} - -AtomicSymbolEntry::~AtomicSymbolEntry() { - if (singleton_symbol) { - delete singleton_symbol; - } -} - template AtomicSymbolEntry &Registry::Register(const std::string& name); template Registry *Registry::Get(); diff --git a/src/static_operator/fully_connect_op-inl.h b/src/static_operator/fully_connect_op-inl.h index 062e17c8ea98..15f8e857d3cf 100644 --- a/src/static_operator/fully_connect_op-inl.h +++ b/src/static_operator/fully_connect_op-inl.h @@ -95,7 +95,7 @@ class FullyConnectOp : public StaticOperator { */ class FullyConnectSymbol : public AtomicSymbol { public: - virtual std::vector DescribeArguments() const { + virtual std::vector ListArguments() const { std::string ret[] = {"data", "weight", "bias"}; if (param_.no_bias == 0) { return std::vector(ret, ret + 3); diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index a81b7ce0cccd..7a1e932ca872 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -7,181 +7,364 @@ #include #include #include -#include +#include +#include +#include +#include namespace mxnet { - -void Symbol::FindArgUsers() { - arg_users_.reset(new std::vector >); - // depth first traversing - std::vector > stk; - stk.push_back({head_.get(), 0}); - while (!stk.empty()) { - std::pair& back = stk.back(); - if (back.first->in_symbol_.size() == 
back.second) { - stk.pop_back(); - } else { - Node* next_level = back.first->in_symbol_[back.second].get(); - if (next_level->sym_) { - stk.push_back({next_level, 0}); - } else { // back uses next_level which is a placeholder - arg_users_->push_back({back.first, back.second}); - } - back.second += 1; - } - } -} - +// copy the symbol Symbol Symbol::Copy() const { - Symbol s; std::unordered_map > old_new; - std::vector stk; - stk.push_back(head_.get()); - // copy nodes - while (!stk.empty()) { - Node* back = stk.back(); - stk.pop_back(); - if (old_new.count(back) == 0) { - if (back->sym_) { - old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); + // use DFSVisit to copy all the nodes + this->DFSVisit([&old_new](Node *node) { + if (node->sym == nullptr) { + old_new[node] = std::make_shared(nullptr, node->name); } else { - old_new[back] = std::make_shared(nullptr, back->name_); - } - } - for (const std::shared_ptr& n : back->in_symbol_) { - if (old_new.count(n.get()) == 0) { - stk.push_back(n.get()); + old_new[node] = std::make_shared(node->sym->Copy(), node->name); } + }); + // connect nodes of new graph + for (const auto &kv : old_new) { + for (const DataEntry& n : kv.first->inputs) { + Node *ptr = n.source.get(); + kv.second->inputs.push_back(DataEntry(old_new[ptr], n.index)); } } - // connect nodes - for (auto kv : old_new) { - for (const std::shared_ptr& n : kv.first->in_symbol_) { - kv.second->in_symbol_.push_back(old_new[n.get()]); - } - } - s.head_ = old_new[this->head_.get()]; - // copy arg_users_ - if (arg_users_) { - s.arg_users_.reset(new std::vector >); - std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(*s.arg_users_), - [&old_new](const std::pair& n) -> std::pair { - return { old_new[n.first].get(), n.second }; - }); - } + // set the head + Symbol s; + s.head_ = DataEntry(old_new[head_.source.get()], head_.index); return s; } -Symbol Symbol::operator () (const std::vector& args) const { - Symbol s = this->Copy(); - 
if (!s.arg_users_) { // if arg_users_ has not been populated - s.FindArgUsers(); +void Symbol::Print(std::ostream &os) const { + if (head_.source->is_atomic()) { + os << "AtomicSymbol "<< " Type:" << head_.source->sym->TypeString() << '\n' + << "Inputs:"; + std::vector args = this->ListArguments(); + for (size_t i = 0; i < args.size(); ++i) { + os << "\targ[" << i << "]=" << args[i] << "\n"; + } + } else { + // use DFSVisit to copy all the nodes + this->DFSVisit([&os](Node *node) { + if (node->is_variable()) { + os << "Variable:" << node->name << '\n'; + } else { + os << "Name: " << node->name << " Type:" << node->sym->TypeString() << '\n' + << "Inputs:\n"; + for (size_t i = 0; i < node->inputs.size(); ++i) { + os << "\targ[" << i << "]=" << node->inputs[i].source->name + << '(' << node->inputs[i].index << ")\n"; + } + } + }); } - CHECK_LT(args.size(), s.arg_users_->size()) << "Too many args, requires " << s.arg_users_->size() - << " provided " << args.size(); +} + +int Symbol::FindDuplicateArgs(std::unordered_map *out) const { + out->clear(); + int max_dup = 1; + this->DFSVisit([out, &max_dup](Node *node) { + if (node->is_variable()) { + auto iter = out->find(node->name); + if (iter == out->end()) { + (*out)[node->name] = 1; + } else { + ++iter->second; + max_dup = std::max(max_dup, iter->second); + } + } + }); + return max_dup; +} + +void Symbol::Compose(const std::vector& args, + const std::string& name) { + CHECK(!head_.source->is_variable()) << "PlaceHolder cannot be composed"; + head_.source->name = name; for (size_t i = 0; i < args.size(); ++i) { - const std::pair& arg_user = (*s.arg_users_)[i]; - arg_user.first->in_symbol_[arg_user.second] = args[i].head_; - CHECK_NE(args[i].index_, -1) << "Argument " << i << " is a tuple, scalar is required"; - arg_user.first->in_index_[arg_user.second] = args[i].index_; + CHECK_NE(args[i].head_.index, -1) + << "Argument " << i << " is a tuple, scalar is required"; + } + // positional arguments requires all arguments for 
now. + // TODO(bing) consider partial assignments + if (head_.source->is_atomic()) { + // atomic symbol do not have place holder for all the arguments + std::vector req_args = head_.source->sym->ListArguments(); + CHECK_EQ(args.size(), req_args.size()) + << "Incorrect number of arguments, requires " << req_args.size() + << ", provided " << args.size(); + head_.source->inputs.resize(args.size()); + for (size_t i = 0; i < args.size(); ++i) { + head_.source->inputs[i] = args[i].head_; + } + } else { + // find all the place holders + size_t arg_counter = 0; + std::unordered_map replace_map; + std::vector > replace_plan; + // replace map stores the existing replacement plan for arguments node + this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args](Node *node) { + // visit all the childs, find possible replacement + for (size_t i = 0; i < node->inputs.size(); ++i) { + DataEntry *e = &(node->inputs[i]); + if (e->source->is_variable()) { + const DataEntry *target = nullptr; + auto iter = replace_map.find(e->source.get()); + if (iter == replace_map.end()) { + if (arg_counter < args.size()) { + target = &(args[arg_counter].head_); + replace_map[e->source.get()] = target; + } + ++arg_counter; + } else { + target = iter->second; + } + replace_plan.push_back(std::make_pair(e, target)); + } + } + }); + CHECK_EQ(args.size(), arg_counter) + << "Incorrect number of arguments, requires " << arg_counter + << ", provided " << args.size(); + // now run the replacement + for (const auto& kv : replace_plan) { + *(kv.first) = *(kv.second); + } } - s.arg_users_.reset(); - return s; } -Symbol Symbol::operator () (const std::unordered_map& kwargs) const { - Symbol s = this->Copy(); - if (!s.arg_users_) { // if arg_users_ has not been populated - s.FindArgUsers(); - } - CHECK_LT(kwargs.size(), s.arg_users_->size()) << "Too many args, requires " - << s.arg_users_->size() << " provided " << kwargs.size(); - for (size_t i = 0; i < s.arg_users_->size(); ++i) { - const std::pair& 
arg_user = (*s.arg_users_)[i]; - const std::string& name = arg_user.first->name_; - if (!(name == "") && kwargs.count(name) != 0) { - const Symbol& bind = kwargs.at(name); - arg_user.first->in_symbol_[arg_user.second] = bind.head_; - CHECK_NE(bind.index_, -1) << "Argument " << name << " is a tuple, scalar is required"; - arg_user.first->in_index_[arg_user.second] = bind.index_; +void Symbol::Compose(const std::unordered_map& kwargs, + const std::string& name) { + CHECK(!head_.source->is_variable()) << "PlaceHolder cannot be composed"; + head_.source->name = name; + for (const auto& kv : kwargs) { + CHECK_NE(kv.second.head_.index, -1) + << "Keyword Argument " << kv.first << " is a tuple, scalar is required"; + } + size_t nmatched = 0; + if (head_.source->is_atomic()) { + // atomic symbol do not have place holder for all the arguments + std::vector req_args = head_.source->sym->ListArguments(); + head_.source->inputs.resize(req_args.size()); + for (size_t i = 0; i < req_args.size(); ++i) { + auto iter = kwargs.find(req_args[i]); + if (iter != kwargs.end()) { + head_.source->inputs[i] = iter->second.head_; + + ++nmatched; + } else { + // create a variable node + // TODO(bing): think of naming convention + if (name.length() == 0) { + head_.source->inputs[i] = DataEntry( + std::make_shared(nullptr, req_args[i]), 0); + } else { + head_.source->inputs[i] = DataEntry( + std::make_shared(nullptr, name + '_' + req_args[i]), 0); + } + } + } + // if things goes wrong recover the old state + if (nmatched != kwargs.size()) { + head_.source->inputs.clear(); + } + } else { + // find all the arguments positions + std::unordered_map dup_args; + int max_dup = this->FindDuplicateArgs(&dup_args); + if (max_dup > 1) { + for (const auto& kv : dup_args) { + CHECK_EQ(kv.second, 1) + << " Argument name=\"" << kv.first << "\" occured in " + << kv.second << " places in the Symbol, " + << "Keyword argument call is not supported because this duplication."; + } + } + CHECK_EQ(max_dup, 1); + 
std::vector > replace_plan; + std::unordered_set visited; + // replace map stores the existing replacement plan for arguments node + this->DFSVisit([&nmatched, &visited, &kwargs, &replace_plan](Node *node) { + // visit all the childs, find possible replacement + for (size_t i = 0; i < node->inputs.size(); ++i) { + DataEntry *e = &(node->inputs[i]); + if (e->source->is_variable()) { + const DataEntry *target = nullptr; + auto iter = kwargs.find(e->source->name); + if (iter != kwargs.end()) { + target = &(iter->second.head_); + // count how many arguments have been matched. + if (visited.count(e->source.get()) == 0) { + visited.insert(e->source.get()); + ++nmatched; + } + replace_plan.push_back(std::make_pair(e, target)); + } + } + } + }); + if (nmatched == kwargs.size()) { + for (const auto& kv : replace_plan) { + *(kv.first) = *(kv.second); + } + } + } + if (nmatched != kwargs.size()) { + // Error message handling + std::vector req_args = this->ListArguments(); + std::unordered_set keys(req_args.begin(), req_args.end()); + std::ostringstream msg; + msg << "\nCandidate arguments:\n"; + for (size_t i = 0; i < req_args.size(); ++i) { + msg << "\t[" << i << ']' << req_args[i] << '\n'; + } + for (const auto& kv : kwargs) { + CHECK_NE(keys.count(kv.first), 0) + << "Keyword Argument " << kv.first << " not found in arguments." 
+ << msg.str(); } } - s.arg_users_.reset(); - // TODO(linmin): report error if kwargs contains non-existing keys - return s; } Symbol Symbol::operator[] (int index) const { - CHECK_EQ(index_, -1) << "Current symbol can't be indexed because it returns a scalar."; + CHECK_EQ(head_.index, -1) << "Current symbol can't be indexed because it returns a scalar."; + CHECK_GE(index, 0) << "Symbol only accept nonnegative index"; Symbol s = *this; - s.index_ = index; + s.head_.index = index; return s; } -std::vector Symbol::ListArgs() { +std::vector Symbol::ListArguments() const { std::vector ret; - if (!arg_users_) { - FindArgUsers(); - } - std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), - [&](const std::pair& n) -> std::string { - return n.first->in_symbol_[n.second]->name_; + if (head_.source->is_atomic()) { + return head_.source->sym->ListArguments(); + } else { + this->DFSVisit([&ret](Node *node) { + if (node->is_variable()) { + ret.push_back(node->name); + } }); - return ret; + return ret; + } } -Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { - Symbol s; - std::vector args = atomic_symbol->DescribeArguments(); - std::vector rets = atomic_symbol->DescribeReturns(); - // set head_ - s.head_ = std::make_shared(atomic_symbol, ""); - // set index_ - s.index_ = rets.size() > 1 ? 
-1 : 0; - // set head_->in_index_ - s.head_->in_index_ = std::vector(args.size(), 0); - // set head_->in_symbol_ - for (auto name : args) { - s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); - } - // set head_->out_shape_ - s.head_->out_shape_ = std::vector(rets.size()); - return s; +inline void Symbol::Toposort(const std::vector &heads, + std::vector *ret) { + std::unordered_map out_degree; + std::queue queue; + ret->clear(); + size_t idx = 0; + DFSVisit(heads, + [&out_degree](Node* node) { + for (auto &entry : node->inputs) { + Node *ptr = entry.source.get(); + auto iter = out_degree.find(ptr); + if (iter == out_degree.end()) { + out_degree[ptr] = 0; + } else { + iter->second += 1; + } + } + }); + for (auto &entry : heads) { + queue.push(entry.source.get()); + } + idx = out_degree.size(); + ret->resize(idx); + --idx; + while (queue.size() > 0) { + Node *node = queue.front(); + queue.pop(); + ret->at(idx--) = node; + for (auto it = node->inputs.rbegin(); it != node->inputs.rend(); ++it) { + Node *ptr = it->source.get(); + out_degree[ptr] -= 1; + if (out_degree[ptr] == 0) { + queue.push(ptr); + } + } + } } -Symbol Symbol::Create(const std::string& type_name, - const std::vector >& param) { - const AtomicSymbolEntry *entry = Registry::Find(type_name); - CHECK_NE(entry, NULL) << type_name << " is not a valid Symbol type"; - AtomicSymbol* atomic_symbol = (*entry)(); - for (auto p : param) { - atomic_symbol->SetParam(p.first.c_str(), p.second.c_str()); - } - return Create(atomic_symbol); +bool Symbol::InferShape(std::vector *in_shape, + std::vector *out_shape) { + bool success = true; + StaticGraph graph; + auto input_args = this->ListArguments(); + std::vector tmp_arg = {*this}; + CHECK(in_shape->size() == input_args.size()) << "Input shape should be same to arguments"; + out_shape->clear(); + Convert(tmp_arg, &graph); + for (size_t i = 0; i < in_shape->size(); ++i) { + graph.nodes[graph.in_args_node_id[i]].in_shape.push_back(in_shape->at(i)); + } + for 
(auto &nd : graph.nodes) { + success &= nd.sym->InferShape(&nd.in_shape, &nd.out_shape); + } + // copy result back + for (size_t i = 0; i < in_shape->size(); ++i) { + in_shape->at(i) = graph.nodes[graph.in_args_node_id[i]].in_shape[0]; + } + for (auto i : graph.return_node_id) { + for (auto sp : graph.nodes[i].out_shape) { + out_shape->push_back(sp); + } + } + return success; } -StaticGraph Symbol::ToStaticGraph() { - StaticGraph graph; - Dfs(this->head_, &graph); - return graph; +std::vector Symbol::ListReturns() const { + return head_.source->sym->ListReturns(); } -CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map& in) { - StaticGraph graph = this->ToStaticGraph(); - return NULL; - // TODO(bing): pass the graph and in to initlialize a composite op. +Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { + // use special representation for atomic symbol + Symbol s; + s.head_ = DataEntry(std::make_shared(atomic_symbol, ""), + atomic_symbol->ListReturns().size() > 1 ? -1 : 0); + return s; } -void Symbol::Dfs(const std::shared_ptr node, StaticGraph *graph) { - int id = graph->FindNodeByName(node->name_, node->sym_); - for (size_t i = 0; i < node->in_symbol_.size(); ++i) { - std::shared_ptr parent = node->in_symbol_[i]; - int parent_id = graph->FindNodeByName(parent->name_, parent->sym_); - graph->connected_graph[parent_id].push_back(id); - graph->output_index[parent_id].push_back(node->in_index_[i]); - if (parent->sym_) { - Dfs(parent, graph); +void Symbol::Convert(const std::vector &heads, StaticGraph *out_graph) { + // TODO(bing): Check unique name + std::vector nodes; + std::unordered_map node_id_dic; + std::vector arg(heads.size()); + for (size_t i = 0; i < heads.size(); ++i) { + arg[i] = heads[i].head_; + } + Toposort(arg, &nodes); + out_graph->nodes.resize(nodes.size()); + // set up dict + for (size_t i = 0; i < nodes.size(); ++i) { + node_id_dic[nodes[i]] = i; + } + // copy + for (size_t i = 0; i < nodes.size(); ++i) { + 
out_graph->name_id_map[nodes[i]->name] = i; + if (!nodes[i]->is_variable()) { + out_graph->nodes[i].sym.reset(nodes[i]->sym->Copy()); } + out_graph->nodes[i].name = nodes[i]->name; + for (auto &entry : nodes[i]->inputs) { + out_graph->nodes[i].inputs_index.push_back(node_id_dic[entry.source.get()]); + out_graph->nodes[node_id_dic[entry.source.get()]].outputs_index.push_back(i); + } + } + // set input map + for (auto const &head : heads) { + auto input_args = head.ListArguments(); + out_graph->in_args_node_id.resize(input_args.size()); + for (size_t i = 0; i < input_args.size(); ++i) { + out_graph->in_args_node_id[i] = out_graph->name_id_map[input_args[i]]; + } + } + // set output map + out_graph->return_node_id.resize(heads.size()); + for (size_t i = 0; i < heads.size(); ++i) { + out_graph->return_node_id[i] = out_graph->name_id_map[heads[i].head_.source->name]; } }