diff --git a/Makefile b/Makefile index 3af4ad9f14d3..0ad9dce5b7aa 100644 --- a/Makefile +++ b/Makefile @@ -58,7 +58,7 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o # add threaded engine after it is done -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o CUOBJ = SLIB = api/libmxnet.so ALIB = api/libmxnet.a @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc static_operator_cpu.o: src/static_operator/static_operator_cpu.cc static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc +atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc operator.o: src/operator/operator.cc diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index 7e351e52cb03..974c5dbfbb4c 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include "./mxnet_api.h" @@ -243,57 +244,43 @@ int MXFuncInvoke(FunctionHandle fun, auto *f = static_cast(fun); (*f)((NArray**)(use_vars), // NOLINT(*) scalar_args, - (NArray**)(mutate_vars)); // NOLINT(*) + (NArray**)(mutate_vars)); // NOLINT(*) API_END(); } -int MXSymFree(SymbolHandle sym) { +int MXSymCreate(const char *type_str, + int num_param, + const char** keys, + const char** vals, + SymbolHandle* out) { API_BEGIN(); - delete static_cast(sym); + CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*) API_END(); } -int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, - mx_uint *use_param) { +int MXSymFree(SymbolHandle sym) { API_BEGIN(); - auto *sc = static_cast(sym_creator); - *use_param = sc->use_param ? 1 : 0; + delete static_cast(sym); API_END(); } -int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, - int count, - const char** keys, - const char** vals, - SymbolHandle* out) { +int MXSymDescribe(const char *type_str, + mx_uint *use_param) { API_BEGIN(); - const SymbolCreatorRegistry::Entry *sc = - static_cast(sym_creator); - sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*) + *use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0; API_END(); } -int MXListSymCreators(mx_uint *out_size, - SymbolCreatorHandle **out_array) { +int MXListSyms(mx_uint *out_size, + const char ***out_array) { API_BEGIN(); - auto &vec = SymbolCreatorRegistry::List(); + auto &vec = AtomicSymbolRegistry::List(); *out_size = static_cast(vec.size()); - *out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*) - API_END(); -} - -int MXGetSymCreator(const char *name, - SymbolCreatorHandle *out) { - API_BEGIN(); - *out = SymbolCreatorRegistry::Find(name); - API_END(); -} - -int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, - const char **out_name) { - API_BEGIN(); - auto *f = static_cast(sym_creator); - *out_name = f->name.c_str(); + std::vector type_strs; + for (auto entry : vec) { + type_strs.push_back(entry->type_str.c_str()); + } + *out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*) API_END(); } diff --git a/api/mxnet_api.h b/api/mxnet_api.h index fb4b8710e2e6..4c357593abce 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -214,6 +214,20 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun, */ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, SymbolHandle *out); +/*! + * \brief invoke registered symbol creator through its handle. + * \param type_str the type of the AtomicSymbol + * \param num_param the number of the key value pairs in the param. + * \param keys an array of c str. + * \param vals the corresponding values of the keys. + * \param out pointer to the created symbol handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXSymCreate(const char *type_str, + int num_param, + const char** keys, + const char** vals, + SymbolHandle* out); /*! * \brief free the symbol handle * \param sym the symbol @@ -222,26 +236,12 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg, MXNET_DLL int MXSymFree(SymbolHandle sym); /*! * \brief query if the symbol creator needs param. - * \param sym_creator the symbol creator handle + * \param type_str the type of the AtomicSymbol * \param use_param describe if the symbol creator requires param * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator, - mx_uint *use_param); -/*! - * \brief invoke registered symbol creator through its handle. - * \param sym_creator pointer to the symbolcreator function. - * \param count the number of the key value pairs in the param. - * \param keys an array of c str. - * \param vals the corresponding values of the keys. - * \param out pointer to the created symbol handle - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, - int count, - const char** keys, - const char** vals, - SymbolHandle* out); +MXNET_DLL int MXSymDescribe(const char *type_str, + mx_uint *use_param); /*! * \brief list all the available sym_creator * most user can use it to list all the needed sym_creators @@ -249,24 +249,8 @@ MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator, * \param out_array the output sym_creators * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXListSymCreators(mx_uint *out_size, - SymbolCreatorHandle **out_array); -/*! - * \brief get the sym_creator by name - * \param name the name of the sym_creator - * \param out the corresponding sym_creator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXGetSymCreator(const char *name, - SymbolCreatorHandle *out); -/*! - * \brief get the name of sym_creator handle - * \param sym_creator the sym_creator handle - * \param out_name the name of the sym_creator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator, - const char **out_name); +MXNET_DLL int MXListSyms(mx_uint *out_size, + const char ***out_array); /*! * \brief compose the symbol on other symbol * \param sym the symbol to apply diff --git a/api/python/mxnet/symbol_creator.py b/api/python/mxnet/symbol_creator.py index e8a49149ec35..ce999e32b267 100644 --- a/api/python/mxnet/symbol_creator.py +++ b/api/python/mxnet/symbol_creator.py @@ -12,7 +12,7 @@ class _SymbolCreator(object): """SymbolCreator is a function that takes Param and return symbol""" - def __init__(self, handle, name): + def __init__(self, name): """Initialize the function with handle Parameters @@ -23,11 +23,10 @@ def __init__(self, handle, name): name : string the name of the function """ - self.handle = handle self.name = name use_param = mx_uint() - check_call(_LIB.MXSymCreatorDescribe( - self.handle, + check_call(_LIB.MXSymDescribe( + c_str(self.name), ctypes.byref(use_param))) self.use_param = use_param.value @@ -45,8 +44,8 @@ def __call__(self, **kwargs): keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()]) vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()]) sym_handle = SymbolHandle() - check_call(_LIB.MXSymCreatorInvoke( - self.handle, + check_call(_LIB.MXSymCreate( + c_str(self.name), mx_uint(len(kwargs)), keys, vals, @@ -56,14 +55,12 @@ def __call__(self, **kwargs): class _SymbolCreatorRegistry(object): """Function Registry""" def __init__(self): - plist = ctypes.POINTER(ctypes.c_void_p)() + plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.MXListSymCreators(ctypes.byref(size), - ctypes.byref(plist))) + check_call(_LIB.MXListSyms(ctypes.byref(size), + ctypes.byref(plist))) hmap = {} for i in range(size.value): - hdl = plist[i] - name = ctypes.c_char_p() - check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name))) - hmap[name.value] = _SymbolCreator(hdl, name.value) + name = plist[i] + hmap[name.value] = _SymbolCreator(name.value) self.__dict__.update(hmap) diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 91083c2ea11d..b2806932dabf 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -212,70 +212,5 @@ class FunctionRegistry { static auto __ ## name ## _narray_fun__ = \ ::mxnet::FunctionRegistry::Get()->Register("" # name) -/*! \brief registry of symbol creator */ -class SymbolCreatorRegistry { - public: - /*! \brief SymbolCreator is a function pointer */ - typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**); - /*! \return get a singleton */ - static SymbolCreatorRegistry *Get(); - /*! \brief keep the SymbolCreator function and its meta information */ - struct Entry { - /*! \brief the name of the symbol creator */ - std::string name; - /*! \brief the body of the function */ - SymbolCreator body; - /*! \brief if the creator requires params to construct */ - bool use_param; - /*! \brief constructor */ - explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {} - /*! \brief setter of body */ - inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; } - /*! \brief setter of use_param */ - inline Entry& set_use_param(bool up) { use_param = up; return *this; } - }; - /*! - * \brief register a name symbol under name - * \param name name of the function - * \return ref to the registered entry, used to set properties - */ - Entry &Register(const std::string& name); - /*! \return list of functions in the registry */ - inline static const std::vector &List() { - return Get()->fun_list_; - } - /*! - * \brief find an symbolcreator entry with corresponding name - * \param name name of the symbolcreator - * \return the corresponding symbolcreator, can be NULL - */ - inline static const Entry *Find(const std::string &name) { - auto &fmap = Get()->fmap_; - auto p = fmap.find(name); - if (p != fmap.end()) { - return p->second; - } else { - return nullptr; - } - } - - private: - /*! \brief list of functions */ - std::vector fun_list_; - /*! \brief map of name->function */ - std::map fmap_; - /*! \brief constructor */ - SymbolCreatorRegistry() {} - /*! \brief destructor */ - ~SymbolCreatorRegistry(); -}; - -/*! - * \brief macro to register symbol creator - */ -#define REGISTER_SYMBOL_CREATOR(name) \ - static auto __ ## name ## _symbol_creator__ = \ - ::mxnet::SymbolCreatorRegistry::Get()->Register("" # name) - } // namespace mxnet #endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/atomic_symbol_registry.h b/include/mxnet/atomic_symbol_registry.h new file mode 100644 index 000000000000..8028bd2cfefa --- /dev/null +++ b/include/mxnet/atomic_symbol_registry.h @@ -0,0 +1,108 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file atomic_symbol_registry.h + * \brief atomic symbol registry interface of mxnet + */ +#ifndef MXNET_ATOMIC_SYMBOL_REGISTRY_H_ +#define MXNET_ATOMIC_SYMBOL_REGISTRY_H_ +#include +// check c++11 +#if DMLC_USE_CXX11 == 0 +#error "cxx11 was required for symbol registry module" +#endif +#include +#include +#include +#include +#include +#include "./base.h" + +namespace mxnet { + +/*! + * \brief Register AtomicSymbol + */ +class AtomicSymbolRegistry { + public: + /*! \return get a singleton */ + static AtomicSymbolRegistry *Get(); + /*! \brief registered entry */ + struct Entry { + /*! \brief constructor */ + explicit Entry(const std::string& type_str) : type_str(type_str), + use_param(true), + body(nullptr) {} + /*! \brief type string of the entry */ + std::string type_str; + /*! \brief whether param is required */ + bool use_param; + /*! \brief function body to create AtomicSymbol */ + std::function body; + /*! + * \brief set if param is needed by this AtomicSymbol + */ + Entry &set_use_param(bool use_param) { + this->use_param = use_param; + return *this; + } + /*! + * \brief set the function body + */ + Entry &set_body(const std::function& body) { + this->body = body; + return *this; + } + }; + /*! + * \brief register the maker function with name + * \return the type string of the AtomicSymbol + */ + template + Entry &Register() { + AtomicType instance; + std::string type_str = instance.TypeString(); + Entry *e = new Entry(type_str); + fmap_[type_str] = e; + fun_list_.push_back(e); + e->set_body([]()->AtomicSymbol* { + return new AtomicType; + }); + return *e; + } + /*! + * \brief find the entry by type string + * \param type_str the type string of the AtomicSymbol + * \return the corresponding entry + */ + inline static const Entry* Find(const std::string& type_str) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(type_str); + if (p != fmap.end()) { + return p->second; + } else { + return nullptr; + } + } + /*! \brief list all the AtomicSymbols */ + inline static const std::vector &List() { + return Get()->fun_list_; + } + /*! \brief make a atomicsymbol according to the typename */ + inline static AtomicSymbol* Make(const std::string& name) { + return Get()->fmap_[name]->body(); + } + + protected: + /*! \brief list of functions */ + std::vector fun_list_; + /*! \brief map of name->function */ + std::unordered_map fmap_; +}; + +/*! \brief macro to register AtomicSymbol to AtomicSymbolFactory */ +#define REGISTER_ATOMIC_SYMBOL(AtomicType) \ + static auto __## AtomicType ## _entry__ = \ + AtomicSymbolRegistry::Get()->Register() + +} // namespace mxnet +#endif // MXNET_ATOMIC_SYMBOL_REGISTRY_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 4afc89b700e2..f024774c0e3b 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -7,6 +7,7 @@ #define MXNET_SYMBOL_H_ #include +#include #include #include #include @@ -95,11 +96,12 @@ class Symbol { virtual std::vector ListArgs(); /*! * \brief create atomic symbol wrapped in symbol + * \param type_str the type string of the AtomicSymbol * \param param the parameter stored as key value pairs * \return the constructed Symbol */ - template - static Symbol CreateSymbol(const std::vector >& param) { + static Symbol CreateSymbol(const std::string& type_str, + const std::vector >& param) { Symbol* s; std::vector keys(param.size()); std::vector vals(param.size()); @@ -107,26 +109,27 @@ class Symbol { keys.push_back(p.first.c_str()); vals.push_back(p.second.c_str()); } - CreateSymbol(param.size(), &keys[0], &vals[0], &s); + CCreateSymbol(type_str.c_str(), param.size(), &keys[0], &vals[0], &s); Symbol ret = *s; delete s; return ret; } /*! * \brief c api for CreateSymbol, this can be registered with SymbolCreatorRegistry + * \param type_str the type string of the AtomicSymbol * \param num_param the number of params * \param keys the key for the params * \param vals values of the params * \param out stores the returning symbol */ - template - friend void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out); + friend void CCreateSymbol(const char* type_str, int num_param, const char** keys, + const char** vals, Symbol** out); }; -template -void CreateSymbol(int num_param, const char** keys, const char** vals, Symbol** out) { +void CCreateSymbol(const char* type_str, int num_param, const char** keys, const char** vals, + Symbol** out) { Symbol* s = new Symbol; - Atomic* atom = new Atomic; + AtomicSymbol* atom = AtomicSymbolRegistry::Make(type_str); for (int i = 0; i < num_param; ++i) { atom->SetParam(keys[i], vals[i]); } diff --git a/src/api_registry.cc b/src/api_registry.cc index c029502287e9..0a6423441bbe 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -29,26 +29,4 @@ FunctionRegistry *FunctionRegistry::Get() { return &instance; } -// SymbolCreatorRegistry - -SymbolCreatorRegistry::Entry& -SymbolCreatorRegistry::Register(const std::string& name) { - CHECK_EQ(fmap_.count(name), 0); - Entry *e = new Entry(name); - fmap_[name] = e; - fun_list_.push_back(e); - return *e; -} - -SymbolCreatorRegistry::~SymbolCreatorRegistry() { - for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { - delete p->second; - } -} - -SymbolCreatorRegistry *SymbolCreatorRegistry::Get() { - static SymbolCreatorRegistry instance; - return &instance; -} - } // namespace mxnet diff --git a/src/symbol/atomic_symbol_registry.cc b/src/symbol/atomic_symbol_registry.cc new file mode 100644 index 000000000000..a9fa8f9b791a --- /dev/null +++ b/src/symbol/atomic_symbol_registry.cc @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file symbol_registry.cc + * \brief symbol_registry of mxnet + */ +#include +#include +#include +#include + +namespace mxnet { + +AtomicSymbolRegistry *AtomicSymbolRegistry::Get() { + static AtomicSymbolRegistry instance; + return &instance; +} + +} // namespace mxnet