From ada652b5c60426de984074dac1994e9576feafa1 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 18:06:29 +0800 Subject: [PATCH 1/6] docs --- include/mxnet/symbol.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index a623cfa1502e..4afc89b700e2 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -32,15 +32,15 @@ class Symbol { * with input symbols. */ struct Node { - /*! wrapped atomic symbol */ + /*! \brief wrapped atomic symbol */ AtomicSymbol* sym_; - /*! name of the node */ + /*! \brief name of the node */ std::string name_; - /*! inputs to this node */ + /*! \brief inputs to this node */ std::vector > in_symbol_; - /*! index of the inputs if the inputs are tuple */ + /*! \brief index of the inputs if the inputs are tuple */ std::vector in_index_; - /*! the output shape of the wrapped symbol */ + /*! \brief the output shape of the wrapped symbol */ std::vector out_shape_; /*! * \brief constructor From 8edf81905d9e4c463059d882dcde3312620fceb1 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 18:32:41 +0800 Subject: [PATCH 2/6] add type string to atomic symbol --- include/mxnet/atomic_symbol.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h index 086bba9c6bae..8d4019f87d62 100644 --- a/include/mxnet/atomic_symbol.h +++ b/include/mxnet/atomic_symbol.h @@ -67,6 +67,11 @@ class AtomicSymbol { * Calling bind from the Symbol wrapper would generate a NArrayOperator. */ virtual Operator* Bind(Context ctx) const = 0; + /*! + * \brief return the type string of the atomic symbol + * subclasses override this function. + */ + virtual std::string TypeString() const = 0; friend class Symbol; }; From 6214cd9e2c9a30e8b8c381f2f0a1c81143e2719f Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 19:20:20 +0800 Subject: [PATCH 3/6] register AtomicSymbol --- Makefile | 3 +- api/mxnet_api.cc | 55 +++++-------- api/mxnet_api.h | 54 +++++-------- api/python/mxnet/symbol_creator.py | 23 +++--- include/mxnet/api_registry.h | 65 --------------- include/mxnet/atomic_symbol_registry.h | 108 +++++++++++++++++++++++++ include/mxnet/symbol.h | 19 +++-- src/api_registry.cc | 22 ----- src/symbol/atomic_symbol_registry.cc | 18 +++++ 9 files changed, 189 insertions(+), 178 deletions(-) create mode 100644 include/mxnet/atomic_symbol_registry.h create mode 100644 src/symbol/atomic_symbol_registry.cc diff --git a/Makefile b/Makefile index 3af4ad9f14d3..0ad9dce5b7aa 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc static_operator_cpu.o: src/static_operator/static_operator_cpu.cc static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc +atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc operator.o: src/operator/operator.cc diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 7e351e52cb03..974c5dbfbb4c 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include "./mxnet_api.h" @@ -243,57 +244,43 @@ int MXFuncInvoke(FunctionHandle fun, auto *f = static_cast(fun); (*f)((NArray**)(use_vars), // NOLINT(*) scalar_args, - (NArray**)(mutate_vars)); // NOLINT(*) + (NArray**)(mutate_vars)); // NOLINT(*) API_END(); } -int MXSymFree(SymbolHandle sym) { +int MXSymCreate(const char *type_str, + int num_param, + const char** keys, + const char** vals, + SymbolHandle* out) { API_BEGIN(); - delete static_cast(sym); + CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*) API_END(); } -int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, - mx_uint *use_param) { +int MXSymFree(SymbolHandle sym) { API_BEGIN(); - auto *sc = static_cast(sym_creator); - *use_param = sc->use_param ? 1 : 0; + delete static_cast(sym); API_END(); } -int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, - int count, - const char** keys, - const char** vals, - SymbolHandle* out) { +int MXSymDescribe(const char *type_str, + mx_uint *use_param) { API_BEGIN(); - const SymbolCreatorRegistry::Entry *sc = - static_cast(sym_creator); - sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*) + *use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0; API_END(); } -int MXListSymCreators(mx_uint *out_size, - SymbolCreatorHandle **out_array) { +int MXListSyms(mx_uint *out_size, + const char ***out_array) { API_BEGIN(); - auto &vec = SymbolCreatorRegistry::List(); + auto &vec = AtomicSymbolRegistry::List(); *out_size = static_cast(vec.size()); - *out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) - API_END(); -} - -int MXGetSymCreator(const char *name, - SymbolCreatorHandle *out) { - API_BEGIN(); - *out = SymbolCreatorRegistry::Find(name); - API_END(); -} - -int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, - const char **out_name) { - API_BEGIN(); - auto *f = static_cast(sym_creator); - *out_name = f->name.c_str(); + std::vector type_strs; + for (auto entry : vec) { + type_strs.push_back(entry->type_str.c_str()); + } + *out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*) API_END(); } diff --git a/api/mxnet_api.h b/api/mxnet_api.h index fb4b8710e2e6..4c357593abce 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -214,6 +214,20 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, */ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, SymbolHandle *out); +/*! + * \brief invoke registered symbol creator through its handle. + * \param type_str the type of the AtomicSymbol + * \param num_param the number of the key value pairs in the param. + * \param keys an array of c str. + * \param vals the corresponding values of the keys. + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreate(const char *type_str, + int num_param, + const char** keys, + const char** vals, + SymbolHandle* out); /*! * \brief free the symbol handle * \param sym the symbol @@ -222,26 +236,12 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, MXNET_DLL int MXSymFree(SymbolHandle sym); /*! * \brief query if the symbol creator needs param. - * \param sym_creator the symbol creator handle + * \param type_str the type of the AtomicSymbol * \param use_param describe if the symbol creator requires param * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, - mx_uint *use_param); -/*! - * \brief invoke registered symbol creator through its handle. - * \param sym_creator pointer to the symbolcreator function. - * \param count the number of the key value pairs in the param. - * \param keys an array of c str. - * \param vals the corresponding values of the keys. - * \param out pointer to the created symbol handle - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, - int count, - const char** keys, - const char** vals, - SymbolHandle* out); +MXNET_DLL int MXSymDescribe(const char *type_str, + mx_uint *use_param); /*! * \brief list all the available sym_creator * most user can use it to list all the needed sym_creators @@ -249,24 +249,8 @@ MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, * \param out_array the output sym_creators * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXListSymCreators(mx_uint *out_size, - SymbolCreatorHandle **out_array); -/*! - * \brief get the sym_creator by name - * \param name the name of the sym_creator - * \param out the corresponding sym_creator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXGetSymCreator(const char *name, - SymbolCreatorHandle *out); -/*! - * \brief get the name of sym_creator handle - * \param sym_creator the sym_creator handle - * \param out_name the name of the sym_creator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, - const char **out_name); +MXNET_DLL int MXListSyms(mx_uint *out_size, + const char ***out_array); /*! * \brief compose the symbol on other symbol * \param sym the symbol to apply diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py index e8a49149ec35..ce999e32b267 100644 --- a/api/python/mxnet/symbol_creator.py +++ b/api/python/mxnet/symbol_creator.py @@ -12,7 +12,7 @@ class _SymbolCreator(object): """SymbolCreator is a function that takes Param and return symbol""" - def __init__(self, handle, name): + def __init__(self, name): """Initialize the function with handle Parameters @@ -23,11 +23,10 @@ def __init__(self, handle, name): name : string the name of the function """ - self.handle = handle self.name = name use_param = mx_uint() - check_call(_LIB.MXSymCreatorDescribe( - self.handle, + check_call(_LIB.MXSymDescribe( + c_str(self.name), ctypes.byref(use_param))) self.use_param = use_param.value @@ -45,8 +44,8 @@ def __call__(self, **kwargs): keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) sym_handle = SymbolHandle() - check_call(_LIB.MXSymCreatorInvoke( - self.handle, + check_call(_LIB.MXSymCreate( + c_str(self.name), mx_uint(len(kwargs)), keys, vals, @@ -56,14 +55,12 @@ def __call__(self, **kwargs): class _SymbolCreatorRegistry(object): """Function Registry""" def __init__(self): - plist = ctypes.POINTER(ctypes.c_void_p)() + plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.MXListSymCreators(ctypes.byref(size), - ctypes.byref(plist))) + check_call(_LIB.MXListSyms(ctypes.byref(size), + ctypes.byref(plist))) hmap = {} for i in range(size.value): - hdl = plist[i] - name = ctypes.c_char_p() - check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name))) - hmap[name.value] = _SymbolCreator(hdl, name.value) + name = plist[i] + hmap[name.value] = _SymbolCreator(name.value) self.__dict__.update(hmap) diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 91083c2ea11d..b2806932dabf 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -212,70 +212,5 @@ class FunctionRegistry { static auto __ ## name ## _narray_fun__ = \ ::mxnet::FunctionRegistry::Get()->Register("" # name) -/*! \brief registry of symbol creator */ -class SymbolCreatorRegistry { - public: - /*! \brief SymbolCreator is a function pointer */ - typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**); - /*! \return get a singleton */ - static SymbolCreatorRegistry *Get(); - /*! \brief keep the SymbolCreator function and its meta information */ - struct Entry { - /*! \brief the name of the symbol creator */ - std::string name; - /*! \brief the body of the function */ - SymbolCreator body; - /*! \brief if the creator requires params to construct */ - bool use_param; - /*! \brief constructor */ - explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {} - /*! \brief setter of body */ - inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; } - /*! \brief setter of use_param */ - inline Entry& set_use_param(bool up) { use_param = up; return *this; } - }; - /*! - * \brief register a name symbol under name - * \param name name of the function - * \return ref to the registered entry, used to set properties - */ - Entry &Register(const std::string& name); - /*! \return list of functions in the registry */ - inline static const std::vector &List() { - return Get()->fun_list_; - } - /*! - * \brief find an symbolcreator entry with corresponding name - * \param name name of the symbolcreator - * \return the corresponding symbolcreator, can be NULL - */ - inline static const Entry *Find(const std::string &name) { - auto &fmap = Get()->fmap_; - auto p = fmap.find(name); - if (p != fmap.end()) { - return p->second; - } else { - return nullptr; - } - } - - private: - /*! \brief list of functions */ - std::vector fun_list_; - /*! \brief map of name->function */ - std::map fmap_; - /*! \brief constructor */ - SymbolCreatorRegistry() {} - /*! \brief destructor */ - ~SymbolCreatorRegistry(); -}; - -/*! - * \brief macro to register symbol creator - */ -#define REGISTER_SYMBOL_CREATOR(name) \ - static auto __ ## name ## _symbol_creator__ = \ - ::mxnet::SymbolCreatorRegistry::Get()->Register("" # name) - } // namespace mxnet #endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/atomic_symbol_registry.h b/include/mxnet/atomic_symbol_registry.h new file mode 100644 index 000000000000..8028bd2cfefa --- /dev/null +++ b/include/mxnet/atomic_symbol_registry.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file atomic_symbol_registry.h + * \brief atomic symbol registry interface of mxnet + */ +#ifndef MXNET_ATOMIC_SYMBOL_REGISTRY_H_ +#define MXNET_ATOMIC_SYMBOL_REGISTRY_H_ +#include +// check c++11 +#if DMLC_USE_CXX11 == 0 +#error "cxx11 was required for symbol registry module" +#endif +#include +#include +#include +#include +#include +#include "./base.h" + +namespace mxnet { + +/*! + * \brief Register AtomicSymbol + */ +class AtomicSymbolRegistry { + public: + /*! \return get a singleton */ + static AtomicSymbolRegistry *Get(); + /*! \brief registered entry */ + struct Entry { + /*! \brief constructor */ + explicit Entry(const std::string& type_str) : type_str(type_str), + use_param(true), + body(nullptr) {} + /*! \brief type string of the entry */ + std::string type_str; + /*! \brief whether param is required */ + bool use_param; + /*! \brief function body to create AtomicSymbol */ + std::function body; + /*! + * \brief set if param is needed by this AtomicSymbol + */ + Entry &set_use_param(bool use_param) { + this->use_param = use_param; + return *this; + } + /*! + * \brief set the function body + */ + Entry &set_body(const std::function& body) { + this->body = body; + return *this; + } + }; + /*! + * \brief register the maker function with name + * \return the type string of the AtomicSymbol + */ + template + Entry &Register() { + AtomicType instance; + std::string type_str = instance.TypeString(); + Entry *e = new Entry(type_str); + fmap_[type_str] = e; + fun_list_.push_back(e); + e->set_body([]()->AtomicSymbol* { + return new AtomicType; + }); + return *e; + } + /*! + * \brief find the entry by type string + * \param type_str the type string of the AtomicSymbol + * \return the corresponding entry + */ + inline static const Entry* Find(const std::string& type_str) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(type_str); + if (p != fmap.end()) { + return p->second; + } else { + return nullptr; + } + } + /*! \brief list all the AtomicSymbols */ + inline static const std::vector &List() { + return Get()->fun_list_; + } + /*! \brief make a atomicsymbol according to the typename */ + inline static AtomicSymbol* Make(const std::string& name) { + return Get()->fmap_[name]->body(); + } + + protected: + /*! \brief list of functions */ + std::vector fun_list_; + /*! \brief map of name->function */ + std::unordered_map fmap_; +}; + +/*! \brief macro to register AtomicSymbol to AtomicSymbolFactory */ +#define REGISTER_ATOMIC_SYMBOL(AtomicType) \ + static auto __## AtomicType ## _entry__ = \ + AtomicSymbolRegistry::Get()->Register() + +} // namespace mxnet +#endif // MXNET_ATOMIC_SYMBOL_REGISTRY_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 4afc89b700e2..f024774c0e3b 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -7,6 +7,7 @@ #define MXNET_SYMBOL_H_ #include +#include #include #include #include @@ -95,11 +96,12 @@ class Symbol { virtual std::vector ListArgs(); /*! * \brief create atomic symbol wrapped in symbol + * \param type_str the type string of the AtomicSymbol * \param param the parameter stored as key value pairs * \return the constructed Symbol */ - template - static Symbol CreateSymbol(const std::vector >& param) { + static Symbol CreateSymbol(const std::string& type_str, + const std::vector >& param) { Symbol* s; std::vector keys(param.size()); std::vector vals(param.size()); @@ -107,26 +109,27 @@ class Symbol { keys.push_back(p.first.c_str()); vals.push_back(p.second.c_str()); } - CreateSymbol(param.size(), &keys[0], &vals[0], &s); + CCreateSymbol(type_str.c_str(), param.size(), &keys[0], &vals[0], &s); Symbol ret = *s; delete s; return ret; } /*! * \brief c api for CreateSymbol, this can be registered with SymbolCreatorRegistry + * \param type_str the type string of the AtomicSymbol * \param num_param the number of params * \param keys the key for the params * \param vals values of the params * \param out stores the returning symbol */ - template - friend void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out); + friend void CCreateSymbol(const char* type_str, int num_param, const char** keys, + const char** vals, Symbol** out); }; -template -void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out) { +void CCreateSymbol(const char* type_str, int num_param, const char** keys, const char** vals, + Symbol** out) { Symbol* s = new Symbol; - Atomic* atom = new Atomic; + AtomicSymbol* atom = AtomicSymbolRegistry::Make(type_str); for (int i = 0; i < num_param; ++i) { atom->SetParam(keys[i], vals[i]); } diff --git a/src/api_registry.cc b/src/api_registry.cc index c029502287e9..0a6423441bbe 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -29,26 +29,4 @@ FunctionRegistry *FunctionRegistry::Get() { return &instance; } -// SymbolCreatorRegistry - -SymbolCreatorRegistry::Entry& -SymbolCreatorRegistry::Register(const std::string& name) { - CHECK_EQ(fmap_.count(name), 0); - Entry *e = new Entry(name); - fmap_[name] = e; - fun_list_.push_back(e); - return *e; -} - -SymbolCreatorRegistry::~SymbolCreatorRegistry() { - for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { - delete p->second; - } -} - -SymbolCreatorRegistry *SymbolCreatorRegistry::Get() { - static SymbolCreatorRegistry instance; - return &instance; -} - } // namespace mxnet diff --git a/src/symbol/atomic_symbol_registry.cc b/src/symbol/atomic_symbol_registry.cc new file mode 100644 index 000000000000..a9fa8f9b791a --- /dev/null +++ b/src/symbol/atomic_symbol_registry.cc @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbol_registry.cc + * \brief symbol_registry of mxnet + */ +#include +#include +#include +#include + +namespace mxnet { + +AtomicSymbolRegistry *AtomicSymbolRegistry::Get() { + static AtomicSymbolRegistry instance; + return &instance; +} + +} // namespace mxnet From f137132e805e94664e1046a459730a97973df26a Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 19:21:24 +0800 Subject: [PATCH 4/6] add virtual destructor --- include/mxnet/symbol.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index f024774c0e3b..ad7b303e0e77 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -62,6 +62,10 @@ class Symbol { void FindArgUsers(); public: + /*! + * \brief declare virtual destructor in case it is subclassed. + */ + virtual ~Symbol() {} /*! * \brief bind to device and returns an NArrayOperator. * \param ctx context of the operator From e3de4dfae47a8253c23ff1d43677c463b0f5f684 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 19:33:16 +0800 Subject: [PATCH 5/6] move definition to cc file --- include/mxnet/symbol.h | 24 ------------------------ src/symbol/symbol.cc | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index ad7b303e0e77..a207158b8f27 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -130,29 +130,5 @@ class Symbol { const char** vals, Symbol** out); }; -void CCreateSymbol(const char* type_str, int num_param, const char** keys, const char** vals, - Symbol** out) { - Symbol* s = new Symbol; - AtomicSymbol* atom = AtomicSymbolRegistry::Make(type_str); - for (int i = 0; i < num_param; ++i) { - atom->SetParam(keys[i], vals[i]); - } - std::vector args = atom->DescribeArguments(); - std::vector rets = atom->DescribeReturns(); - // set head_ - s->head_ = std::make_shared(atom, ""); - // set index_ - s->index_ = rets.size() > 1 ? -1 : 0; - // set head_->in_index_ - s->head_->in_index_ = std::vector(args.size(), 0); - // set head_->in_symbol_ - for (auto name : args) { - s->head_->in_symbol_.push_back(std::make_shared(nullptr, name)); - } - // set head_->out_shape_ - s->head_->out_shape_ = std::vector(rets.size()); - *out = s; -} - } // namespace mxnet #endif // MXNET_SYMBOL_H_ diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index d44f63b52d95..a8650df3c29c 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -138,4 +138,28 @@ std::vector Symbol::ListArgs() { return ret; } +void CCreateSymbol(const char* type_str, int num_param, const char** keys, const char** vals, + Symbol** out) { + Symbol* s = new Symbol; + AtomicSymbol* atom = AtomicSymbolRegistry::Make(type_str); + for (int i = 0; i < num_param; ++i) { + atom->SetParam(keys[i], vals[i]); + } + std::vector args = atom->DescribeArguments(); + std::vector rets = atom->DescribeReturns(); + // set head_ + s->head_ = std::make_shared(atom, ""); + // set index_ + s->index_ = rets.size() > 1 ? -1 : 0; + // set head_->in_index_ + s->head_->in_index_ = std::vector(args.size(), 0); + // set head_->in_symbol_ + for (auto name : args) { + s->head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + } + // set head_->out_shape_ + s->head_->out_shape_ = std::vector(rets.size()); + *out = s; +} + } // namespace mxnet From 6325cc406279f76a8cfc87c921eb8634634e1dba Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 17 Jul 2015 19:33:43 +0800 Subject: [PATCH 6/6] fix doc --- include/mxnet/operator.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 03e90173fe9e..4dbcea798cdd 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -27,6 +27,11 @@ namespace mxnet { */ class Operator { public: + /*! + * \brief construct Operator from StaticOperator and Context + * \param op StaticOperator to wrap + * \param ctx Context of the Operator + */ Operator(StaticOperator* op, Context ctx); /*! * \brief get types of input argument of this oeprator @@ -61,6 +66,10 @@ class Operator { virtual void InferShape(std::vector *in_shape, std::vector *out_shape); + /*! + * \brief set the context of the Operator + * \param ctx the context to be set to + */ virtual void SetContext(Context ctx); /*! * \brief perform a forward operation of operator, save the output to TBlob