From 1275470af75f843e56f83c5a4961a41ee4432095 Mon Sep 17 00:00:00 2001 From: linmin Date: Sat, 18 Jul 2015 15:13:25 +0800 Subject: [PATCH 1/3] move api_registry to registry --- Makefile | 4 +- api/mxnet_api.cc | 12 +- include/mxnet/api_registry.h | 216 ---------------------------------- include/mxnet/registry.h | 222 +++++++++++++++++++++++++++++++++++ src/api_registry.cc | 32 ----- src/narray/narray.cc | 2 +- src/registry.cc | 32 +++++ test/api_registry_test.cc | 4 +- 8 files changed, 265 insertions(+), 259 deletions(-) delete mode 100644 include/mxnet/api_registry.h create mode 100644 include/mxnet/registry.h delete mode 100644 src/api_registry.cc create mode 100644 src/registry.cc diff --git a/Makefile b/Makefile index 0ad9dce5b7aa..ab0f7a6b2f9d 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o +OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o atomic_symbol_registry.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -86,7 +86,7 @@ static_operator_cpu.o: src/static_operator/static_operator_cpu.cc static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc -api_registry.o: src/api_registry.cc +registry.o: src/registry.cc mxnet_api.o: api/mxnet_api.cc operator.o: src/operator/operator.cc diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 974c5dbfbb4c..b619b3dc3fe7 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include "./mxnet_api.h" @@ -201,7 +201,7 @@ int MXNArrayGetContext(NArrayHandle handle, int MXListFunctions(mx_uint *out_size, FunctionHandle **out_array) { API_BEGIN(); - auto &vec = FunctionRegistry::List(); + auto &vec = Registry::List(); *out_size = static_cast(vec.size()); *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); @@ -210,14 +210,14 @@ int MXListFunctions(mx_uint *out_size, int MXGetFunction(const char *name, FunctionHandle *out) { API_BEGIN(); - *out = FunctionRegistry::Find(name); + *out = Registry::Find(name); API_END(); } int MXFuncGetName(FunctionHandle fun, const char **out_name) { API_BEGIN(); - auto *f = static_cast(fun); + auto *f = static_cast(fun); *out_name = f->name.c_str(); API_END(); } @@ -228,7 +228,7 @@ int MXFuncDescribe(FunctionHandle fun, mx_uint *num_mutate_vars, int *type_mask) { API_BEGIN(); - auto *f = static_cast(fun); + auto *f = static_cast(fun); *num_use_vars = f->num_use_vars; *num_scalars = f->num_scalars; *num_mutate_vars = f->num_mutate_vars; @@ -241,7 +241,7 @@ int MXFuncInvoke(FunctionHandle fun, mx_float *scalar_args, NArrayHandle *mutate_vars) { API_BEGIN(); - auto *f = static_cast(fun); + auto *f = static_cast(fun); (*f)((NArray**)(use_vars), // NOLINT(*) scalar_args, (NArray**)(mutate_vars)); // NOLINT(*) diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h deleted file mode 100644 index b2806932dabf..000000000000 --- a/include/mxnet/api_registry.h +++ /dev/null @@ -1,216 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file api_registry.h - * \brief api registry that registers functions - * for C API module and possiblity other modules - */ -#ifndef MXNET_API_REGISTRY_H_ -#define MXNET_API_REGISTRY_H_ -#include -// check c++11 -#if DMLC_USE_CXX11 == 0 -#error "cxx11 was required for api registry module" -#endif -#include -#include -#include -#include -#include "./base.h" -#include "./narray.h" -#include "./symbol.h" - -namespace mxnet { - -/*! \brief mask information on how functions can be exposed */ -enum FunctionTypeMask { - /*! \brief all the use_vars should go before scalar */ - kNArrayArgBeforeScalar = 1, - /*! \brief all the scalar should go before use_vars */ - kScalarArgBeforeNArray = 1 << 1, - /*! - * \brief whether this function allows the handles in the target to - * be empty NArray that are not yet initialized, and will initialize - * them when the function is invoked. - * - * most function should support this, except copy between different - * devices, which requires the NArray to be pre-initialized with context - */ - kAcceptEmptyMutateTarget = 1 << 2 -}; - -/*! \brief registry of NArray functions */ -class FunctionRegistry { - public: - /*! \brief definition of NArray function */ - typedef std::function Function; - /*! \brief registry entry */ - struct Entry { - /*! \brief function name */ - std::string name; - /*! \brief number of variable used by this function */ - unsigned num_use_vars; - /*! \brief number of variable mutated by this function */ - unsigned num_mutate_vars; - /*! \brief number of scalars used by this function */ - unsigned num_scalars; - /*! \brief information on how function should be called from API */ - int type_mask; - /*! \brief the real function */ - Function body; - /*! - * \brief constructor - * \param name name of the function - */ - explicit Entry(const std::string &name) - : name(name), - num_use_vars(0), - num_mutate_vars(0), - num_scalars(0), - type_mask(0), - body(nullptr) {} - /*! - * \brief set the number of mutate variables - * \param n number of mutate variablesx - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_num_use_vars(unsigned n) { - num_use_vars = n; return *this; - } - /*! - * \brief set the number of mutate variables - * \param n number of mutate variablesx - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_num_mutate_vars(unsigned n) { - num_mutate_vars = n; return *this; - } - /*! - * \brief set the number of scalar arguments - * \param n number of scalar arguments - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_num_scalars(unsigned n) { - num_scalars = n; return *this; - } - /*! - * \brief set the function body - * \param f function body to set - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_body(Function f) { - body = f; return *this; - } - /*! - * \brief set type mask - * \param tmask typemask - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_type_mask(int tmask) { - type_mask = tmask; return *this; - } - /*! - * \brief set the function body to a binary NArray function - * this will also auto set the parameters correctly - * \param fbinary function body to set - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_function(void fbinary(const NArray &lhs, - const NArray &rhs, - NArray *out)) { - body = [fbinary] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { - fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); - }; - num_use_vars = 2; num_mutate_vars = 1; - type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; - return *this; - } - /*! - * \brief set the function body to a unary NArray function - * this will also auto set the parameters correctly - * \param funary function body to set - * \return ref to the registered entry, used to set properties - */ - inline Entry &set_function(void funary(const NArray &src, - NArray *out)) { - body = [funary] (NArray **used_vars, - real_t *s, NArray **mutate_vars) { - funary(*used_vars[0], mutate_vars[0]); - }; - num_use_vars = 1; num_mutate_vars = 1; - type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; - return *this; - } - /*! - * \brief invoke the function - * \param use_vars variables used by the function - * \param scalars the scalar arguments passed to function - * \param mutate_vars the variables mutated by the function - */ - inline void operator()(NArray **use_vars, - real_t *scalars, - NArray **mutate_vars) const { - body(use_vars, scalars, mutate_vars); - } - }; // Entry - /*! \return get a singleton */ - static FunctionRegistry *Get(); - /*! - * \brief register a name function under name - * \param name name of the function - * \return ref to the registered entry, used to set properties - */ - Entry &Register(const std::string name); - /*! \return list of functions in the registry */ - inline static const std::vector &List() { - return Get()->fun_list_; - } - /*! - * \brief find an function entry with corresponding name - * \param name name of the function - * \return the corresponding function, can be NULL - */ - inline static const Entry *Find(const std::string &name) { - auto &fmap = Get()->fmap_; - auto p = fmap.find(name); - if (p != fmap.end()) { - return p->second; - } else { - return nullptr; - } - } - - private: - /*! \brief list of functions */ - std::vector fun_list_; - /*! \brief map of name->function */ - std::map fmap_; - /*! \brief constructor */ - FunctionRegistry() {} - /*! \brief destructor */ - ~FunctionRegistry(); -}; - -/*! - * \brief macro to register NArray function - * - * Example: the following code is example to register aplus - * \code - * - * REGISTER_NARRAY_FUN(Plus) - * .set_body([] (NArray **used_vars, real_t *scalars, NArray **mutate_vars) { - * BinaryPlus(*used_vars[0], *used_vars[1], mutate_vars[0]); - * }) - * .set_num_use_vars(2) - * .set_num_mutate_vars(1); - * - * \endcode - */ -#define REGISTER_NARRAY_FUN(name) \ - static auto __ ## name ## _narray_fun__ = \ - ::mxnet::FunctionRegistry::Get()->Register("" # name) - -} // namespace mxnet -#endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h new file mode 100644 index 000000000000..ab6501f60b86 --- /dev/null +++ b/include/mxnet/registry.h @@ -0,0 +1,222 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file registry.h + * \brief registry that registers all sorts of functions + */ +#ifndef MXNET_API_REGISTRY_H_ +#define MXNET_API_REGISTRY_H_ + +#include +#include +#include +#include +#include "./base.h" +#include "./narray.h" +#include "./symbol.h" + +namespace mxnet { + +/*! \brief registry template */ +template +class Registry { + public: + /*! \return get a singleton */ + static Registry *Get(); + /*! + * \brief register a name function under name + * \param name name of the function + * \return ref to the registered entry, used to set properties + */ + Entry &Register(const std::string& name); + /*! \return list of functions in the registry */ + inline static const std::vector &List() { + return Get()->fun_list_; + } + /*! + * \brief find an function entry with corresponding name + * \param name name of the function + * \return the corresponding function, can be NULL + */ + inline static const Entry *Find(const std::string &name) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(name); + if (p != fmap.end()) { + return p->second; + } else { + return NULL; // c++11 is not required + } + } + + private: + /*! \brief list of functions */ + std::vector fun_list_; + /*! \brief map of name->function */ + std::map fmap_; + /*! \brief constructor */ + Registry() {} + /*! \brief destructor */ + ~Registry() { + for (typename std::map::iterator p = fmap_.begin(); + p != fmap_.end(); ++p) { + delete p->second; + } + } +}; + +/*! NArrayFunctionEntry requires c++11 */ +#if DMLC_USE_CXX11 +#include +/*! \brief mask information on how functions can be exposed */ +enum FunctionTypeMask { + /*! \brief all the use_vars should go before scalar */ + kNArrayArgBeforeScalar = 1, + /*! \brief all the scalar should go before use_vars */ + kScalarArgBeforeNArray = 1 << 1, + /*! + * \brief whether this function allows the handles in the target to + * be empty NArray that are not yet initialized, and will initialize + * them when the function is invoked. + * + * most function should support this, except copy between different + * devices, which requires the NArray to be pre-initialized with context + */ + kAcceptEmptyMutateTarget = 1 << 2 +}; + +/*! \brief registry entry */ +struct NArrayFunctionEntry { + /*! \brief definition of NArray function */ + typedef std::function Function; + /*! \brief function name */ + std::string name; + /*! \brief number of variable used by this function */ + unsigned num_use_vars; + /*! \brief number of variable mutated by this function */ + unsigned num_mutate_vars; + /*! \brief number of scalars used by this function */ + unsigned num_scalars; + /*! \brief information on how function should be called from API */ + int type_mask; + /*! \brief the real function */ + Function body; + /*! + * \brief constructor + * \param name name of the function + */ + explicit NArrayFunctionEntry(const std::string &name) + : name(name), + num_use_vars(0), + num_mutate_vars(0), + num_scalars(0), + type_mask(0), + body(nullptr) {} + /*! + * \brief set the number of mutate variables + * \param n number of mutate variablesx + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_num_use_vars(unsigned n) { + num_use_vars = n; return *this; + } + /*! + * \brief set the number of mutate variables + * \param n number of mutate variablesx + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_num_mutate_vars(unsigned n) { + num_mutate_vars = n; return *this; + } + /*! + * \brief set the number of scalar arguments + * \param n number of scalar arguments + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_num_scalars(unsigned n) { + num_scalars = n; return *this; + } + /*! + * \brief set the function body + * \param f function body to set + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_body(Function f) { + body = f; return *this; + } + /*! + * \brief set type mask + * \param tmask typemask + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_type_mask(int tmask) { + type_mask = tmask; return *this; + } + /*! + * \brief set the function body to a binary NArray function + * this will also auto set the parameters correctly + * \param fbinary function body to set + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_function(void fbinary(const NArray &lhs, + const NArray &rhs, + NArray *out)) { + body = [fbinary] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); + }; + num_use_vars = 2; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + return *this; + } + /*! + * \brief set the function body to a unary NArray function + * this will also auto set the parameters correctly + * \param funary function body to set + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionEntry &set_function(void funary(const NArray &src, + NArray *out)) { + body = [funary] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + funary(*used_vars[0], mutate_vars[0]); + }; + num_use_vars = 1; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + return *this; + } + /*! + * \brief invoke the function + * \param use_vars variables used by the function + * \param scalars the scalar arguments passed to function + * \param mutate_vars the variables mutated by the function + */ + inline void operator()(NArray **use_vars, + real_t *scalars, + NArray **mutate_vars) const { + body(use_vars, scalars, mutate_vars); + } +}; // NArrayFunctionEntry + +/*! + * \brief macro to register NArray function + * + * Example: the following code is example to register aplus + * \code + * + * REGISTER_NARRAY_FUN(Plus) + * .set_body([] (NArray **used_vars, real_t *scalars, NArray **mutate_vars) { + * BinaryPlus(*used_vars[0], *used_vars[1], mutate_vars[0]); + * }) + * .set_num_use_vars(2) + * .set_num_mutate_vars(1); + * + * \endcode + */ +#define REGISTER_NARRAY_FUN(name) \ + static auto __ ## name ## _narray_fun__ = \ + ::mxnet::Registry::Get()->Register("" # name) +#endif // DMLC_USE_CXX11 + +} // namespace mxnet +#endif // MXNET_API_REGISTRY_H_ diff --git a/src/api_registry.cc b/src/api_registry.cc deleted file mode 100644 index 0a6423441bbe..000000000000 --- a/src/api_registry.cc +++ /dev/null @@ -1,32 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file api_registry.cc - * \brief - */ -#include -#include -#include - -namespace mxnet { - -FunctionRegistry::Entry & -FunctionRegistry::Register(const std::string name) { - CHECK_EQ(fmap_.count(name), 0); - Entry *e = new Entry(name); - fmap_[name] = e; - fun_list_.push_back(e); - return *e; -} - -FunctionRegistry::~FunctionRegistry() { - for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { - delete p->second; - } -} - -FunctionRegistry *FunctionRegistry::Get() { - static FunctionRegistry instance; - return &instance; -} - -} // namespace mxnet diff --git a/src/narray/narray.cc b/src/narray/narray.cc index c84cf3f05c44..831041bd1496 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -5,7 +5,7 @@ */ #include #include -#include +#include #include #include "./narray_op.h" diff --git a/src/registry.cc b/src/registry.cc new file mode 100644 index 000000000000..64c38f029b3a --- /dev/null +++ b/src/registry.cc @@ -0,0 +1,32 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file api_registry.cc + * \brief + */ +#include +#include +#include + +namespace mxnet { + +template +Entry &Registry::Register(const std::string& name) { + CHECK_EQ(fmap_.count(name), 0); + Entry *e = new Entry(name); + fmap_[name] = e; + fun_list_.push_back(e); + return *e; +} + +template +Registry *Registry::Get() { + static Registry instance; + return &instance; +} + +#if DMLC_USE_CXX11 +template NArrayFunctionEntry &Registry::Register(const std::string& name); +template Registry *Registry::Get(); +#endif + +} // namespace mxnet diff --git a/test/api_registry_test.cc b/test/api_registry_test.cc index 5ef43d4d4025..0f7cef3ba858 100644 --- a/test/api_registry_test.cc +++ b/test/api_registry_test.cc @@ -1,10 +1,10 @@ // Copyright (c) 2015 by Contributors // dummy code to test layer interface // used to demonstrate how interface can be used -#include +#include int main(int argc, char *argv[]) { - auto fadd = mxnet::FunctionRegistry::Find("Plus"); + auto fadd = mxnet::Registry::Find("Plus"); printf("f.name=%s\n", fadd->name.c_str()); return 0; } From 5bd0fb89ff9670c9830860588edea5b51b554ed4 Mon Sep 17 00:00:00 2001 From: linmin Date: Sat, 18 Jul 2015 21:33:48 +0800 Subject: [PATCH 2/3] move AtomicSymbol Registry to common Registry --- Makefile | 3 +- api/mxnet_api.cc | 54 ++++++++----- api/mxnet_api.h | 61 ++++++++------ api/python/mxnet/symbol_creator.py | 28 +++---- include/mxnet/atomic_symbol_registry.h | 108 ------------------------- include/mxnet/registry.h | 76 ++++++++++++++++- include/mxnet/symbol.h | 34 ++------ src/registry.cc | 22 +++++ src/symbol/atomic_symbol_registry.cc | 18 ----- src/symbol/symbol.cc | 37 +++++---- 10 files changed, 208 insertions(+), 233 deletions(-) delete mode 100644 include/mxnet/atomic_symbol_registry.h delete mode 100644 src/symbol/atomic_symbol_registry.cc diff --git a/Makefile b/Makefile index ab0f7a6b2f9d..33dc31fdddbc 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o atomic_symbol_registry.o +OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -85,7 +85,6 @@ static_operator.o: src/static_operator/static_operator.cc static_operator_cpu.o: src/static_operator/static_operator_cpu.cc static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc -atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc registry.o: src/registry.cc mxnet_api.o: api/mxnet_api.cc operator.o: src/operator/operator.cc diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index b619b3dc3fe7..9e4f9d242762 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -7,7 +7,8 @@ #include #include #include -#include +#include +#include #include #include #include "./mxnet_api.h" @@ -248,39 +249,50 @@ int MXFuncInvoke(FunctionHandle fun, API_END(); } -int MXSymCreate(const char *type_str, - int num_param, - const char** keys, - const char** vals, - SymbolHandle* out) { +int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out) { API_BEGIN(); - CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*) + AtomicSymbolEntry *e = static_cast(creator); + *out = static_cast(new Symbol); + AtomicSymbol *atomic_symbol = (*e)(); + for (int i = 0; i < num_param; ++i) { + atomic_symbol->SetParam(keys[i], vals[i]); + } + *static_cast(*out) = Symbol::Create(atomic_symbol); API_END(); } -int MXSymFree(SymbolHandle sym) { +int MXSymbolFree(SymbolHandle symbol) { API_BEGIN(); - delete static_cast(sym); + delete static_cast(symbol); API_END(); } -int MXSymDescribe(const char *type_str, - mx_uint *use_param) { +int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array) { API_BEGIN(); - *use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0; + auto &vec = Registry::List(); + *out_size = static_cast(vec.size()); + *out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*) API_END(); } -int MXListSyms(mx_uint *out_size, - const char ***out_array) { +int MXSymbolGetSingleton(AtomicSymbolCreator creator, + SymbolHandle *out) { API_BEGIN(); - auto &vec = AtomicSymbolRegistry::List(); - *out_size = static_cast(vec.size()); - std::vector type_strs; - for (auto entry : vec) { - type_strs.push_back(entry->type_str.c_str()); - } - *out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*) + AtomicSymbolEntry *e = static_cast(creator); + *out = static_cast(e->GetSingletonSymbol()); + API_END(); +} + +int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out) { + API_BEGIN(); + AtomicSymbolEntry *e = static_cast(creator); + *out = e->name.c_str(); API_END(); } diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 4c357593abce..d53afb0b2386 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -29,9 +29,11 @@ typedef void *NArrayHandle; /*! \brief handle to a mxnet narray function that changes NArray */ typedef const void *FunctionHandle; /*! \brief handle to a function that takes param and creates symbol */ -typedef const void *SymbolCreatorHandle; +typedef void *AtomicSymbolCreator; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; +/*! \brief handle to a AtomicSymbol */ +typedef void *AtomicSymbolHandle; /*! \brief handle to a NArrayOperator */ typedef void *OperatorHandle; /*! \brief handle to a DataIterator */ @@ -212,45 +214,52 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, * \param out created symbol handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymCreateFromConfig(const char *cfg, - SymbolHandle *out); +MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg, + SymbolHandle *out); /*! - * \brief invoke registered symbol creator through its handle. - * \param type_str the type of the AtomicSymbol - * \param num_param the number of the key value pairs in the param. - * \param keys an array of c str. - * \param vals the corresponding values of the keys. + * \brief create Symbol by wrapping AtomicSymbol + * \param creator the AtomicSymbolCreator + * \param num_param the number of parameters + * \param keys the keys to the params + * \param vals the vals of the params * \param out pointer to the created symbol handle * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymCreate(const char *type_str, - int num_param, - const char** keys, - const char** vals, - SymbolHandle* out); +MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator, + int num_param, + const char **keys, + const char **vals, + SymbolHandle *out); /*! * \brief free the symbol handle - * \param sym the symbol + * \param symbol the symbol * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymFree(SymbolHandle sym); +MXNET_DLL int MXSymbolFree(SymbolHandle symbol); /*! - * \brief query if the symbol creator needs param. - * \param type_str the type of the AtomicSymbol - * \param use_param describe if the symbol creator requires param + * \brief list all the available AtomicSymbolEntry + * \param out_size the size of returned array + * \param out_array the output AtomicSymbolCreator array * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymDescribe(const char *type_str, - mx_uint *use_param); +MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, + AtomicSymbolCreator **out_array); /*! - * \brief list all the available sym_creator - * most user can use it to list all the needed sym_creators - * \param out_size the size of returned array - * \param out_array the output sym_creators + * \brief get the singleton Symbol of the AtomicSymbol if any + * \param creator the AtomicSymbolCreator + * \param out the returned singleton Symbol of the AtomicSymbol the creator stands for + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator, + SymbolHandle *out); +/*! + * \brief get the singleton Symbol of the AtomicSymbol if any + * \param creator the AtomicSymbolCreator + * \param out the returned name of the creator * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXListSyms(mx_uint *out_size, - const char ***out_array); +MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator, + const char **out); /*! * \brief compose the symbol on other symbol * \param sym the symbol to apply diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py index ce999e32b267..9950f8b04ed9 100644 --- a/api/python/mxnet/symbol_creator.py +++ b/api/python/mxnet/symbol_creator.py @@ -12,7 +12,7 @@ class _SymbolCreator(object): """SymbolCreator is a function that takes Param and return symbol""" - def __init__(self, name): + def __init__(self, name, handle): """Initialize the function with handle Parameters @@ -24,18 +24,17 @@ def __init__(self, name): the name of the function """ self.name = name - use_param = mx_uint() - check_call(_LIB.MXSymDescribe( - c_str(self.name), - ctypes.byref(use_param))) - self.use_param = use_param.value + self.handle = handle + singleton_ = ctypes.c_void_p() + check_call(_LIB.MXSymbolGetSingleton(self.handle, ctypes.byref(singleton_))) + self.singleton = Symbol(singleton_) def __call__(self, **kwargs): """Invoke creator of symbol by passing kwargs Parameters ---------- - params : kwargs + params : **kwargs provide the params necessary for the symbol creation Returns ------- @@ -44,8 +43,8 @@ def __call__(self, **kwargs): keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) sym_handle = SymbolHandle() - check_call(_LIB.MXSymCreate( - c_str(self.name), + check_call(_LIB.MXSymbolCreateFromAtomicSymbol( + self.handle, mx_uint(len(kwargs)), keys, vals, @@ -55,12 +54,13 @@ def __call__(self, **kwargs): class _SymbolCreatorRegistry(object): """Function Registry""" def __init__(self): - plist = ctypes.POINTER(ctypes.c_char_p)() + plist = ctypes.POINTER(ctypes.c_void_p)() size = ctypes.c_uint() - check_call(_LIB.MXListSyms(ctypes.byref(size), - ctypes.byref(plist))) + check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size), + ctypes.byref(plist))) hmap = {} + name = ctypes.c_char_p() for i in range(size.value): - name = plist[i] - hmap[name.value] = _SymbolCreator(name.value) + name = _LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name)) + hmap[name.value] = _SymbolCreator(name.value, plist[i]) self.__dict__.update(hmap) diff --git a/include/mxnet/atomic_symbol_registry.h b/include/mxnet/atomic_symbol_registry.h deleted file mode 100644 index 8028bd2cfefa..000000000000 --- a/include/mxnet/atomic_symbol_registry.h +++ /dev/null @@ -1,108 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file atomic_symbol_registry.h - * \brief atomic symbol registry interface of mxnet - */ -#ifndef MXNET_ATOMIC_SYMBOL_REGISTRY_H_ -#define MXNET_ATOMIC_SYMBOL_REGISTRY_H_ -#include -// check c++11 -#if DMLC_USE_CXX11 == 0 -#error "cxx11 was required for symbol registry module" -#endif -#include -#include -#include -#include -#include -#include "./base.h" - -namespace mxnet { - -/*! - * \brief Register AtomicSymbol - */ -class AtomicSymbolRegistry { - public: - /*! \return get a singleton */ - static AtomicSymbolRegistry *Get(); - /*! \brief registered entry */ - struct Entry { - /*! \brief constructor */ - explicit Entry(const std::string& type_str) : type_str(type_str), - use_param(true), - body(nullptr) {} - /*! \brief type string of the entry */ - std::string type_str; - /*! \brief whether param is required */ - bool use_param; - /*! \brief function body to create AtomicSymbol */ - std::function body; - /*! - * \brief set if param is needed by this AtomicSymbol - */ - Entry &set_use_param(bool use_param) { - this->use_param = use_param; - return *this; - } - /*! - * \brief set the function body - */ - Entry &set_body(const std::function& body) { - this->body = body; - return *this; - } - }; - /*! - * \brief register the maker function with name - * \return the type string of the AtomicSymbol - */ - template - Entry &Register() { - AtomicType instance; - std::string type_str = instance.TypeString(); - Entry *e = new Entry(type_str); - fmap_[type_str] = e; - fun_list_.push_back(e); - e->set_body([]()->AtomicSymbol* { - return new AtomicType; - }); - return *e; - } - /*! - * \brief find the entry by type string - * \param type_str the type string of the AtomicSymbol - * \return the corresponding entry - */ - inline static const Entry* Find(const std::string& type_str) { - auto &fmap = Get()->fmap_; - auto p = fmap.find(type_str); - if (p != fmap.end()) { - return p->second; - } else { - return nullptr; - } - } - /*! \brief list all the AtomicSymbols */ - inline static const std::vector &List() { - return Get()->fun_list_; - } - /*! \brief make a atomicsymbol according to the typename */ - inline static AtomicSymbol* Make(const std::string& name) { - return Get()->fmap_[name]->body(); - } - - protected: - /*! \brief list of functions */ - std::vector fun_list_; - /*! \brief map of name->function */ - std::unordered_map fmap_; -}; - -/*! \brief macro to register AtomicSymbol to AtomicSymbolFactory */ -#define REGISTER_ATOMIC_SYMBOL(AtomicType) \ - static auto __## AtomicType ## _entry__ = \ - AtomicSymbolRegistry::Get()->Register() - -} // namespace mxnet -#endif // MXNET_ATOMIC_SYMBOL_REGISTRY_H_ diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index ab6501f60b86..7ec0f4c709f9 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -3,8 +3,8 @@ * \file registry.h * \brief registry that registers all sorts of functions */ -#ifndef MXNET_API_REGISTRY_H_ -#define MXNET_API_REGISTRY_H_ +#ifndef MXNET_REGISTRY_H_ +#define MXNET_REGISTRY_H_ #include #include @@ -218,5 +218,75 @@ struct NArrayFunctionEntry { ::mxnet::Registry::Get()->Register("" # name) #endif // DMLC_USE_CXX11 +class Symbol; +/*! \brief AtomicSymbolEntry to register */ +struct AtomicSymbolEntry { + /*! \brief typedef Creator function */ + typedef AtomicSymbol*(*Creator)(); + /*! \brief if AtomicSymbol use param */ + bool use_param; + /*! \brief name of the entry */ + std::string name; + /*! \brief function body to create AtomicSymbol */ + Creator body; + /*! \brief singleton is created when no param is needed for the AtomicSymbol */ + Symbol *singleton_symbol; + /*! \brief constructor */ + explicit AtomicSymbolEntry(const std::string& name) + : use_param(true), name(name), body(NULL), singleton_symbol(NULL) {} + /*! + * \brief set if param is needed by this AtomicSymbol + */ + inline AtomicSymbolEntry &set_use_param(bool use_param) { + this->use_param = use_param; + return *this; + } + /*! + * \brief set the function body + */ + inline AtomicSymbolEntry &set_body(Creator body) { + this->body = body; + return *this; + } + /*! + * \brief return the singleton symbol + */ + Symbol *GetSingletonSymbol(); + /*! \brief destructor */ + ~AtomicSymbolEntry(); + /*! + * \brief invoke the function + * \return the created AtomicSymbol + */ + inline AtomicSymbol* operator () () const { + return body(); + } + + private: + /*! \brief disable copy constructor */ + AtomicSymbolEntry(const AtomicSymbolEntry& other) {} + /*! \brief disable assignment operator */ + const AtomicSymbolEntry& operator = (const AtomicSymbolEntry& other) { return *this; } +}; + +/*! + * \brief macro to register AtomicSymbol to AtomicSymbolFactory + * + * Example: the following code is example to register aplus + * \code + * + * REGISTER_ATOMIC_SYMBOL(fullc) + * .set_use_param(false) + * + * \endcode + */ +#define REGISTER_ATOMIC_SYMBOL(name, AtomicSymbolType) \ + AtomicSymbol* __make_ ## AtomicSymbolType ## __() { \ + return new AtomicSymbolType; \ + } \ + static AtomicSymbolEntry& __ ## name ## _atomic_symbol__ = \ + ::mxnet::Registry::Get()->Register("" # name) \ + .set_body(__make_ ## AtomicSymbolType ## __) + } // namespace mxnet -#endif // MXNET_API_REGISTRY_H_ +#endif // MXNET_REGISTRY_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index a207158b8f27..6f0b5ef8af93 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -7,7 +7,7 @@ #define MXNET_SYMBOL_H_ #include -#include +#include #include #include #include @@ -98,36 +98,18 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); + /*! + * \brief create Symbol by wrapping AtomicSymbol + */ + static Symbol Create(AtomicSymbol* atomic_symbol); /*! * \brief create atomic symbol wrapped in symbol - * \param type_str the type string of the AtomicSymbol + * \param type_name the type string of the AtomicSymbol * \param param the parameter stored as key value pairs * \return the constructed Symbol */ - static Symbol CreateSymbol(const std::string& type_str, - const std::vector >& param) { - Symbol* s; - std::vector keys(param.size()); - std::vector vals(param.size()); - for (auto p : param) { - keys.push_back(p.first.c_str()); - vals.push_back(p.second.c_str()); - } - CCreateSymbol(type_str.c_str(), param.size(), &keys[0], &vals[0], &s); - Symbol ret = *s; - delete s; - return ret; - } - /*! - * \brief c api for CreateSymbol, this can be registered with SymbolCreatorRegistry - * \param type_str the type string of the AtomicSymbol - * \param num_param the number of params - * \param keys the key for the params - * \param vals values of the params - * \param out stores the returning symbol - */ - friend void CCreateSymbol(const char* type_str, int num_param, const char** keys, - const char** vals, Symbol** out); + static Symbol Create(const std::string& type_name, + const std::vector >& param); }; } // namespace mxnet diff --git a/src/registry.cc b/src/registry.cc index 64c38f029b3a..d51907dbdcd7 100644 --- a/src/registry.cc +++ b/src/registry.cc @@ -6,6 +6,7 @@ #include #include #include +#include namespace mxnet { @@ -29,4 +30,25 @@ template NArrayFunctionEntry &Registry::Register(const std: template Registry *Registry::Get(); #endif +Symbol *AtomicSymbolEntry::GetSingletonSymbol() { + if (singleton_symbol) { + return singleton_symbol; + } else if (body && !use_param) { + singleton_symbol = new Symbol; + *singleton_symbol = Symbol::Create(body()); + return singleton_symbol; + } else { + return NULL; + } +} + +AtomicSymbolEntry::~AtomicSymbolEntry() { + if (singleton_symbol) { + delete singleton_symbol; + } +} + +template AtomicSymbolEntry &Registry::Register(const std::string& name); +template Registry *Registry::Get(); + } // namespace mxnet diff --git a/src/symbol/atomic_symbol_registry.cc b/src/symbol/atomic_symbol_registry.cc deleted file mode 100644 index a9fa8f9b791a..000000000000 --- a/src/symbol/atomic_symbol_registry.cc +++ /dev/null @@ -1,18 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file symbol_registry.cc - * \brief symbol_registry of mxnet - */ -#include -#include -#include -#include - -namespace mxnet { - -AtomicSymbolRegistry *AtomicSymbolRegistry::Get() { - static AtomicSymbolRegistry instance; - return &instance; -} - -} // namespace mxnet diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index a8650df3c29c..a4d966ba422f 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -5,6 +5,7 @@ */ #include #include +#include #include namespace mxnet { @@ -138,28 +139,34 @@ std::vector Symbol::ListArgs() { return ret; } -void CCreateSymbol(const char* type_str, int num_param, const char** keys, const char** vals, - Symbol** out) { - Symbol* s = new Symbol; - AtomicSymbol* atom = AtomicSymbolRegistry::Make(type_str); - for (int i = 0; i < num_param; ++i) { - atom->SetParam(keys[i], vals[i]); - } - std::vector args = atom->DescribeArguments(); - std::vector rets = atom->DescribeReturns(); +Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { + Symbol s; + std::vector args = atomic_symbol->DescribeArguments(); + std::vector rets = atomic_symbol->DescribeReturns(); // set head_ - s->head_ = std::make_shared(atom, ""); + s.head_ = std::make_shared(atomic_symbol, ""); // set index_ - s->index_ = rets.size() > 1 ? -1 : 0; + s.index_ = rets.size() > 1 ? -1 : 0; // set head_->in_index_ - s->head_->in_index_ = std::vector(args.size(), 0); + s.head_->in_index_ = std::vector(args.size(), 0); // set head_->in_symbol_ for (auto name : args) { - s->head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); } // set head_->out_shape_ - s->head_->out_shape_ = std::vector(rets.size()); - *out = s; + s.head_->out_shape_ = std::vector(rets.size()); + return s; +} + +Symbol Symbol::Create(const std::string& type_name, + const std::vector >& param) { + const AtomicSymbolEntry *entry = Registry::Find(type_name); + CHECK_NE(entry, NULL) << type_name << " is not a valid Symbol type"; + AtomicSymbol* atomic_symbol = (*entry)(); + for (auto p : param) { + atomic_symbol->SetParam(p.first.c_str(), p.second.c_str()); + } + return Create(atomic_symbol); } } // namespace mxnet From f2a4bc7991ca69c125b38ce79b891af394d52614 Mon Sep 17 00:00:00 2001 From: linmin Date: Sun, 19 Jul 2015 13:38:16 +0800 Subject: [PATCH 3/3] repair some bugs --- api/python/mxnet/symbol_creator.py | 9 ++++++--- include/mxnet/registry.h | 10 +++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py index 9950f8b04ed9..1f30ef8d8b91 100644 --- a/api/python/mxnet/symbol_creator.py +++ b/api/python/mxnet/symbol_creator.py @@ -25,16 +25,19 @@ def __init__(self, name, handle): """ self.name = name self.handle = handle - singleton_ = ctypes.c_void_p() + singleton_ = SymbolHandle() check_call(_LIB.MXSymbolGetSingleton(self.handle, ctypes.byref(singleton_))) - self.singleton = Symbol(singleton_) + if singleton_: + self.singleton = Symbol(singleton_) + else: + self.singleton = None def __call__(self, **kwargs): """Invoke creator of symbol by passing kwargs Parameters ---------- - params : **kwargs + **kwargs provide the params necessary for the symbol creation Returns ------- diff --git a/include/mxnet/registry.h b/include/mxnet/registry.h index 7ec0f4c709f9..1c27d2a3807c 100644 --- a/include/mxnet/registry.h +++ b/include/mxnet/registry.h @@ -38,8 +38,8 @@ class Registry { * \return the corresponding function, can be NULL */ inline static const Entry *Find(const std::string &name) { - auto &fmap = Get()->fmap_; - auto p = fmap.find(name); + const std::map &fmap = Get()->fmap_; + typename std::map::const_iterator p = fmap.find(name); if (p != fmap.end()) { return p->second; } else { @@ -281,11 +281,11 @@ struct AtomicSymbolEntry { * \endcode */ #define REGISTER_ATOMIC_SYMBOL(name, AtomicSymbolType) \ - AtomicSymbol* __make_ ## AtomicSymbolType ## __() { \ + ::mxnet::AtomicSymbol* __make_ ## AtomicSymbolType ## __() { \ return new AtomicSymbolType; \ } \ - static AtomicSymbolEntry& __ ## name ## _atomic_symbol__ = \ - ::mxnet::Registry::Get()->Register("" # name) \ + static ::mxnet::AtomicSymbolEntry& __ ## name ## _atomic_symbol__ = \ + ::mxnet::Registry<::mxnet::AtomicSymbolEntry>::Get()->Register("" # name) \ .set_body(__make_ ## AtomicSymbolType ## __) } // namespace mxnet