Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

change capi #5

Merged
merged 6 commits into from
Jul 17, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ endif
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o
CUOBJ =
SLIB = api/libmxnet.so
ALIB = api/libmxnet.a
Expand All @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc
static_operator_cpu.o: src/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc
api_registry.o: src/api_registry.cc
mxnet_api.o: api/mxnet_api.cc
operator.o: src/operator/operator.cc
Expand Down
55 changes: 21 additions & 34 deletions api/mxnet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <dmlc/logging.h>
#include <mxnet/base.h>
#include <mxnet/narray.h>
#include <mxnet/atomic_symbol_registry.h>
#include <mxnet/api_registry.h>
#include <mutex>
#include "./mxnet_api.h"
Expand Down Expand Up @@ -243,57 +244,43 @@ int MXFuncInvoke(FunctionHandle fun,
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
(*f)((NArray**)(use_vars), // NOLINT(*)
scalar_args,
(NArray**)(mutate_vars)); // NOLINT(*)
(NArray**)(mutate_vars)); // NOLINT(*)
API_END();
}

int MXSymFree(SymbolHandle sym) {
int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out) {
API_BEGIN();
delete static_cast<Symbol*>(sym);
CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*)
API_END();
}

int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator,
mx_uint *use_param) {
int MXSymFree(SymbolHandle sym) {
API_BEGIN();
auto *sc = static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
*use_param = sc->use_param ? 1 : 0;
delete static_cast<Symbol*>(sym);
API_END();
}

int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator,
int count,
const char** keys,
const char** vals,
SymbolHandle* out) {
int MXSymDescribe(const char *type_str,
mx_uint *use_param) {
API_BEGIN();
const SymbolCreatorRegistry::Entry *sc =
static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
sc->body(count, keys, vals, (Symbol**)(out)); // NOLINT(*)
*use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0;
API_END();
}

int MXListSymCreators(mx_uint *out_size,
SymbolCreatorHandle **out_array) {
int MXListSyms(mx_uint *out_size,
const char ***out_array) {
API_BEGIN();
auto &vec = SymbolCreatorRegistry::List();
auto &vec = AtomicSymbolRegistry::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (SymbolCreatorHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}

int MXGetSymCreator(const char *name,
SymbolCreatorHandle *out) {
API_BEGIN();
*out = SymbolCreatorRegistry::Find(name);
API_END();
}

int MXSymCreatorGetName(SymbolCreatorHandle sym_creator,
const char **out_name) {
API_BEGIN();
auto *f = static_cast<const SymbolCreatorRegistry::Entry *>(sym_creator);
*out_name = f->name.c_str();
std::vector<const char*> type_strs;
for (auto entry : vec) {
type_strs.push_back(entry->type_str.c_str());
}
*out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*)
API_END();
}

Expand Down
54 changes: 19 additions & 35 deletions api/mxnet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
*/
MXNET_DLL int MXSymCreateFromConfig(const char *cfg,
SymbolHandle *out);
/*!
* \brief invoke registered symbol creator through its handle.
* \param type_str the type of the AtomicSymbol
* \param num_param the number of the key value pairs in the param.
* \param keys an array of c str.
* \param vals the corresponding values of the keys.
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out);
/*!
* \brief free the symbol handle
* \param sym the symbol
Expand All @@ -222,51 +236,21 @@ MXNET_DLL int MXSymCreateFromConfig(const char *cfg,
MXNET_DLL int MXSymFree(SymbolHandle sym);
/*!
* \brief query if the symbol creator needs param.
* \param sym_creator the symbol creator handle
* \param type_str the type of the AtomicSymbol
* \param use_param describe if the symbol creator requires param
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorDescribe(SymbolCreatorHandle sym_creator,
mx_uint *use_param);
/*!
* \brief invoke registered symbol creator through its handle.
* \param sym_creator pointer to the symbolcreator function.
* \param count the number of the key value pairs in the param.
* \param keys an array of c str.
* \param vals the corresponding values of the keys.
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorInvoke(SymbolCreatorHandle sym_creator,
int count,
const char** keys,
const char** vals,
SymbolHandle* out);
MXNET_DLL int MXSymDescribe(const char *type_str,
mx_uint *use_param);
/*!
* \brief list all the available sym_creator
* most user can use it to list all the needed sym_creators
* \param out_size the size of returned array
* \param out_array the output sym_creators
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListSymCreators(mx_uint *out_size,
SymbolCreatorHandle **out_array);
/*!
* \brief get the sym_creator by name
* \param name the name of the sym_creator
* \param out the corresponding sym_creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXGetSymCreator(const char *name,
SymbolCreatorHandle *out);
/*!
* \brief get the name of sym_creator handle
* \param sym_creator the sym_creator handle
* \param out_name the name of the sym_creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreatorGetName(SymbolCreatorHandle sym_creator,
const char **out_name);
MXNET_DLL int MXListSyms(mx_uint *out_size,
const char ***out_array);
/*!
* \brief compose the symbol on other symbol
* \param sym the symbol to apply
Expand Down
23 changes: 10 additions & 13 deletions api/python/mxnet/symbol_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class _SymbolCreator(object):
"""SymbolCreator is a function that takes Param and return symbol"""

def __init__(self, handle, name):
def __init__(self, name):
"""Initialize the function with handle

Parameters
Expand All @@ -23,11 +23,10 @@ def __init__(self, handle, name):
name : string
the name of the function
"""
self.handle = handle
self.name = name
use_param = mx_uint()
check_call(_LIB.MXSymCreatorDescribe(
self.handle,
check_call(_LIB.MXSymDescribe(
c_str(self.name),
ctypes.byref(use_param)))
self.use_param = use_param.value

Expand All @@ -45,8 +44,8 @@ def __call__(self, **kwargs):
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
sym_handle = SymbolHandle()
check_call(_LIB.MXSymCreatorInvoke(
self.handle,
check_call(_LIB.MXSymCreate(
c_str(self.name),
mx_uint(len(kwargs)),
keys,
vals,
Expand All @@ -56,14 +55,12 @@ def __call__(self, **kwargs):
class _SymbolCreatorRegistry(object):
"""Function Registry"""
def __init__(self):
plist = ctypes.POINTER(ctypes.c_void_p)()
plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()
check_call(_LIB.MXListSymCreators(ctypes.byref(size),
ctypes.byref(plist)))
check_call(_LIB.MXListSyms(ctypes.byref(size),
ctypes.byref(plist)))
hmap = {}
for i in range(size.value):
hdl = plist[i]
name = ctypes.c_char_p()
check_call(_LIB.MXSymCreatorGetName(hdl, ctypes.byref(name)))
hmap[name.value] = _SymbolCreator(hdl, name.value)
name = plist[i]
hmap[name.value] = _SymbolCreator(name.value)
self.__dict__.update(hmap)
65 changes: 0 additions & 65 deletions include/mxnet/api_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,70 +212,5 @@ class FunctionRegistry {
static auto __ ## name ## _narray_fun__ = \
::mxnet::FunctionRegistry::Get()->Register("" # name)

/*! \brief registry of symbol creator */
class SymbolCreatorRegistry {
public:
/*! \brief SymbolCreator is a function pointer */
typedef void(*SymbolCreator)(int count, const char**, const char**, Symbol**);
/*! \return get a singleton */
static SymbolCreatorRegistry *Get();
/*! \brief keep the SymbolCreator function and its meta information */
struct Entry {
/*! \brief the name of the symbol creator */
std::string name;
/*! \brief the body of the function */
SymbolCreator body;
/*! \brief if the creator requires params to construct */
bool use_param;
/*! \brief constructor */
explicit Entry(const std::string& name) : name(name), body(nullptr), use_param(true) {}
/*! \brief setter of body */
inline Entry& set_body(SymbolCreator sc) { body = sc; return *this; }
/*! \brief setter of use_param */
inline Entry& set_use_param(bool up) { use_param = up; return *this; }
};
/*!
* \brief register a name symbol under name
* \param name name of the function
* \return ref to the registered entry, used to set properties
*/
Entry &Register(const std::string& name);
/*! \return list of functions in the registry */
inline static const std::vector<const Entry*> &List() {
return Get()->fun_list_;
}
/*!
* \brief find an symbolcreator entry with corresponding name
* \param name name of the symbolcreator
* \return the corresponding symbolcreator, can be NULL
*/
inline static const Entry *Find(const std::string &name) {
auto &fmap = Get()->fmap_;
auto p = fmap.find(name);
if (p != fmap.end()) {
return p->second;
} else {
return nullptr;
}
}

private:
/*! \brief list of functions */
std::vector<const Entry*> fun_list_;
/*! \brief map of name->function */
std::map<std::string, Entry*> fmap_;
/*! \brief constructor */
SymbolCreatorRegistry() {}
/*! \brief destructor */
~SymbolCreatorRegistry();
};

/*!
* \brief macro to register symbol creator
*/
#define REGISTER_SYMBOL_CREATOR(name) \
static auto __ ## name ## _symbol_creator__ = \
::mxnet::SymbolCreatorRegistry::Get()->Register("" # name)

} // namespace mxnet
#endif // MXNET_API_REGISTRY_H_
5 changes: 5 additions & 0 deletions include/mxnet/atomic_symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ class AtomicSymbol {
* Calling bind from the Symbol wrapper would generate a NArrayOperator.
*/
virtual Operator* Bind(Context ctx) const = 0;
/*!
* \brief return the type string of the atomic symbol
* subclasses override this function.
*/
virtual std::string TypeString() const = 0;
friend class Symbol;
};

Expand Down
Loading