Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

new symbol interface #9

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions include/mxnet/atomic_symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ class AtomicSymbol {
*/
virtual ~AtomicSymbol() {}
/*! \brief get the descriptions of inputs for this symbol */
virtual std::vector<std::string> DescribeArguments() const {
virtual std::vector<std::string> ListArguments() const {
// default implementation returns "data"
return std::vector<std::string>(1, std::string("data"));
}
/*! \brief get the descriptions of outputs for this symbol */
virtual std::vector<std::string> DescribeReturns() const {
virtual std::vector<std::string> ListReturns() const {
// default implementation returns "output"
return std::vector<std::string>(1, std::string("output"));
}
Expand Down Expand Up @@ -77,6 +77,13 @@ class AtomicSymbol {
*/
virtual std::string TypeString() const = 0;
friend class Symbol;

/*!
* \brief create atomic symbol by type name
* \param type_name the type string of the AtomicSymbol
* \return a new constructed AtomicSymbol
*/
static AtomicSymbol *Create(const char* type_name);
};

} // namespace mxnet
Expand Down
89 changes: 63 additions & 26 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,21 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
// Part 3: symbolic configuration generation
//--------------------------------------------
/*!
* \brief create symbol from config
* \param cfg configuration string
* \param out created symbol handle
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);
/*!
* \brief Get the name of AtomicSymbol.
* \param creator the AtomicSymbolCreator
* \param out the returned name of the creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
/*!
* \brief create Symbol by wrapping AtomicSymbol
* \param creator the AtomicSymbolCreator
Expand All @@ -231,50 +239,79 @@ MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
const char **vals,
SymbolHandle *out);
/*!
* \brief free the symbol handle
* \brief Create a Variable Symbol.
* \param name name of the variable
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateVariable(const char *name, SymbolHandle *out);
/*!
* \brief Create symbol from config.
* \param cfg configuration string
* \param out created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
/*!
* \brief Free the symbol handle.
* \param symbol the symbol
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolFree(SymbolHandle symbol);
/*!
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \brief Copy the symbol to another handle
* \param symbol the source symbol
* \param out used to hold the result of copy
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);
MXNET_DLL int MXSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
/*!
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned singleton Symbol of the AtomicSymbol the creator stands for
* \brief Print the content of symbol, used for debug.
* \param symbol the symbol
* \param out_str pointer to hold the output string of the printing.
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator,
SymbolHandle *out);
MXNET_DLL int MXSymbolPrint(SymbolHandle symbol, const char **out_str);
/*!
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned name of the creator
* \brief List arguments in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
MXNET_DLL int MXSymbolListArguments(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief compose the symbol on other symbol
* \brief List returns in the symbol.
* \param Symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolListReturns(SymbolHandle symbol,
mx_uint *out_size,
const char ***out_str_array);
/*!
* \brief Compose the symbol on other symbols.
*
* This function will change the sym hanlde.
* To achieve function apply behavior, copy the symbol first
* before apply.
*
* \param sym the symbol to apply
* \pram name the name of symbol
* \param num_args number of arguments
* \param keys the key of keyword args (optional)
* \param args arguments to sym
* \param out the resulting symbol
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
const char *name,
mx_uint num_args,
const char** keys,
SymbolHandle* args,
SymbolHandle* out);

SymbolHandle* args);
//--------------------------------------------
// Part 4: operator interface on NArray
//--------------------------------------------
Expand Down
17 changes: 1 addition & 16 deletions include/mxnet/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,31 +229,16 @@ struct AtomicSymbolEntry {
std::string name;
/*! \brief function body to create AtomicSymbol */
Creator body;
/*! \brief singleton is created when no param is needed for the AtomicSymbol */
Symbol *singleton_symbol;
/*! \brief constructor */
explicit AtomicSymbolEntry(const std::string& name)
: use_param(true), name(name), body(NULL), singleton_symbol(NULL) {}
/*!
* \brief set if param is needed by this AtomicSymbol
*/
inline AtomicSymbolEntry &set_use_param(bool use_param) {
this->use_param = use_param;
return *this;
}
: use_param(true), name(name), body(NULL) {}
/*!
* \brief set the function body
*/
inline AtomicSymbolEntry &set_body(Creator body) {
this->body = body;
return *this;
}
/*!
* \brief return the singleton symbol
*/
Symbol *GetSingletonSymbol();
/*! \brief destructor */
~AtomicSymbolEntry();
/*!
* \brief invoke the function
* \return the created AtomicSymbol
Expand Down
43 changes: 16 additions & 27 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,28 @@ struct StaticGraph {
/*! \brief Node in static graph */
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
std::unique_ptr<AtomicSymbol> sym;
/*! \brief name of the node */
std::string name_;
std::string name;
/*! \brief index of output from the source. */
int index;
/*! \brief output shape for node */
std::vector<TShape> in_shape;
/*! \brief output shape for node */
std::vector<TShape> out_shape;
/*! \brief input id for each node */
std::vector<int> inputs_index;
/*! \brief output id for each node */
std::vector<int> outputs_index;
};
/*! \brief head node (need input from outside) */
std::vector<int> in_args_node_id;
/*! \brief tail node (generate data to outside) */
std::vector<int> return_node_id;
/*! \brief node name to id dictionary */
std::unordered_map<std::string, int> name_id_map;
/*! \brief all nodes in the graph */
std::vector<StaticNode> nodes;
/*! \brief output id for each node */
std::vector<std::vector<int> > output_index;
/*! \brief connected graph for each node */
std::vector<std::vector<int> > connected_graph;
/*! \brief find node by using name
* \param name node name
* \param sym symbol need to be copied into node
* \return node id
*/
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
};
} // namespace mxnet
#endif // MXNET_STATIC_GRAPH_H_
Loading