Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

clean up registry code #6

Merged
merged 3 commits into from
Jul 19, 2015
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ endif
BIN = test/api_registry_test
OBJ = storage.o narray_op_cpu.o static_operator.o static_operator_cpu.o
# add threaded engine after it is done
OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o symbol.o operator.o atomic_symbol_registry.o
OBJCXX11 = engine.o narray.o mxnet_api.o registry.o symbol.o operator.o
CUOBJ =
SLIB = api/libmxnet.so
ALIB = api/libmxnet.a
Expand All @@ -85,8 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc
static_operator_cpu.o: src/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
atomic_symbol_registry.o: src/symbol/atomic_symbol_registry.cc
api_registry.o: src/api_registry.cc
registry.o: src/registry.cc
mxnet_api.o: api/mxnet_api.cc
operator.o: src/operator/operator.cc

Expand Down
66 changes: 39 additions & 27 deletions api/mxnet_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
#include <dmlc/logging.h>
#include <mxnet/base.h>
#include <mxnet/narray.h>
#include <mxnet/atomic_symbol_registry.h>
#include <mxnet/api_registry.h>
#include <mxnet/symbol.h>
#include <mxnet/atomic_symbol.h>
#include <mxnet/registry.h>
#include <mutex>
#include "./mxnet_api.h"

Expand Down Expand Up @@ -201,7 +202,7 @@ int MXNArrayGetContext(NArrayHandle handle,
int MXListFunctions(mx_uint *out_size,
FunctionHandle **out_array) {
API_BEGIN();
auto &vec = FunctionRegistry::List();
auto &vec = Registry<NArrayFunctionEntry>::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
Expand All @@ -210,14 +211,14 @@ int MXListFunctions(mx_uint *out_size,
int MXGetFunction(const char *name,
FunctionHandle *out) {
API_BEGIN();
*out = FunctionRegistry::Find(name);
*out = Registry<NArrayFunctionEntry>::Find(name);
API_END();
}

int MXFuncGetName(FunctionHandle fun,
const char **out_name) {
API_BEGIN();
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
auto *f = static_cast<const NArrayFunctionEntry*>(fun);
*out_name = f->name.c_str();
API_END();
}
Expand All @@ -228,7 +229,7 @@ int MXFuncDescribe(FunctionHandle fun,
mx_uint *num_mutate_vars,
int *type_mask) {
API_BEGIN();
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
auto *f = static_cast<const NArrayFunctionEntry*>(fun);
*num_use_vars = f->num_use_vars;
*num_scalars = f->num_scalars;
*num_mutate_vars = f->num_mutate_vars;
Expand All @@ -241,46 +242,57 @@ int MXFuncInvoke(FunctionHandle fun,
mx_float *scalar_args,
NArrayHandle *mutate_vars) {
API_BEGIN();
auto *f = static_cast<const FunctionRegistry::Entry *>(fun);
auto *f = static_cast<const NArrayFunctionEntry*>(fun);
(*f)((NArray**)(use_vars), // NOLINT(*)
scalar_args,
(NArray**)(mutate_vars)); // NOLINT(*)
API_END();
}

int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out) {
int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out) {
API_BEGIN();
CCreateSymbol(type_str, num_param, keys, vals, (Symbol**)out); // NOLINT(*)
AtomicSymbolEntry *e = static_cast<AtomicSymbolEntry *>(creator);
*out = static_cast<SymbolHandle>(new Symbol);
AtomicSymbol *atomic_symbol = (*e)();
for (int i = 0; i < num_param; ++i) {
atomic_symbol->SetParam(keys[i], vals[i]);
}
*static_cast<Symbol*>(*out) = Symbol::Create(atomic_symbol);
API_END();
}

int MXSymFree(SymbolHandle sym) {
int MXSymbolFree(SymbolHandle symbol) {
API_BEGIN();
delete static_cast<Symbol*>(sym);
delete static_cast<Symbol*>(symbol);
API_END();
}

int MXSymDescribe(const char *type_str,
mx_uint *use_param) {
int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array) {
API_BEGIN();
*use_param = AtomicSymbolRegistry::Find(type_str)->use_param ? 1 : 0;
auto &vec = Registry<AtomicSymbolEntry>::List();
*out_size = static_cast<mx_uint>(vec.size());
*out_array = (AtomicSymbolCreator*)(dmlc::BeginPtr(vec)); // NOLINT(*)
API_END();
}

int MXListSyms(mx_uint *out_size,
const char ***out_array) {
int MXSymbolGetSingleton(AtomicSymbolCreator creator,
SymbolHandle *out) {
API_BEGIN();
auto &vec = AtomicSymbolRegistry::List();
*out_size = static_cast<mx_uint>(vec.size());
std::vector<const char*> type_strs;
for (auto entry : vec) {
type_strs.push_back(entry->type_str.c_str());
}
*out_array = (const char**)(dmlc::BeginPtr(type_strs)); // NOLINT(*)
AtomicSymbolEntry *e = static_cast<AtomicSymbolEntry *>(creator);
*out = static_cast<SymbolHandle>(e->GetSingletonSymbol());
API_END();
}

int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out) {
API_BEGIN();
AtomicSymbolEntry *e = static_cast<AtomicSymbolEntry *>(creator);
*out = e->name.c_str();
API_END();
}

Expand Down
61 changes: 35 additions & 26 deletions api/mxnet_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ typedef void *NArrayHandle;
/*! \brief handle to a mxnet narray function that changes NArray */
typedef const void *FunctionHandle;
/*! \brief handle to a function that takes param and creates symbol */
typedef const void *SymbolCreatorHandle;
typedef void *AtomicSymbolCreator;
/*! \brief handle to a symbol that can be bind as operator */
typedef void *SymbolHandle;
/*! \brief handle to a AtomicSymbol */
typedef void *AtomicSymbolHandle;
/*! \brief handle to a NArrayOperator */
typedef void *OperatorHandle;
/*! \brief handle to a DataIterator */
Expand Down Expand Up @@ -212,45 +214,52 @@ MXNET_DLL int MXFuncInvoke(FunctionHandle fun,
* \param out created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreateFromConfig(const char *cfg,
SymbolHandle *out);
MXNET_DLL int MXSymbolCreateFromConfig(const char *cfg,
SymbolHandle *out);
/*!
* \brief invoke registered symbol creator through its handle.
* \param type_str the type of the AtomicSymbol
* \param num_param the number of the key value pairs in the param.
* \param keys an array of c str.
* \param vals the corresponding values of the keys.
* \brief create Symbol by wrapping AtomicSymbol
* \param creator the AtomicSymbolCreator
* \param num_param the number of parameters
* \param keys the keys to the params
* \param vals the vals of the params
* \param out pointer to the created symbol handle
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymCreate(const char *type_str,
int num_param,
const char** keys,
const char** vals,
SymbolHandle* out);
MXNET_DLL int MXSymbolCreateFromAtomicSymbol(AtomicSymbolCreator creator,
int num_param,
const char **keys,
const char **vals,
SymbolHandle *out);
/*!
* \brief free the symbol handle
* \param sym the symbol
* \param symbol the symbol
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymFree(SymbolHandle sym);
MXNET_DLL int MXSymbolFree(SymbolHandle symbol);
/*!
* \brief query if the symbol creator needs param.
* \param type_str the type of the AtomicSymbol
* \param use_param describe if the symbol creator requires param
* \brief list all the available AtomicSymbolEntry
* \param out_size the size of returned array
* \param out_array the output AtomicSymbolCreator array
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymDescribe(const char *type_str,
mx_uint *use_param);
MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size,
AtomicSymbolCreator **out_array);
/*!
* \brief list all the available sym_creator
* most user can use it to list all the needed sym_creators
* \param out_size the size of returned array
* \param out_array the output sym_creators
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned singleton Symbol of the AtomicSymbol the creator stands for
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXSymbolGetSingleton(AtomicSymbolCreator creator,
SymbolHandle *out);
/*!
* \brief get the singleton Symbol of the AtomicSymbol if any
* \param creator the AtomicSymbolCreator
* \param out the returned name of the creator
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXListSyms(mx_uint *out_size,
const char ***out_array);
MXNET_DLL int MXSymbolGetAtomicSymbolName(AtomicSymbolCreator creator,
const char **out);
/*!
* \brief compose the symbol on other symbol
* \param sym the symbol to apply
Expand Down
28 changes: 14 additions & 14 deletions api/python/mxnet/symbol_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class _SymbolCreator(object):
"""SymbolCreator is a function that takes Param and return symbol"""

def __init__(self, name):
def __init__(self, name, handle):
"""Initialize the function with handle

Parameters
Expand All @@ -24,18 +24,17 @@ def __init__(self, name):
the name of the function
"""
self.name = name
use_param = mx_uint()
check_call(_LIB.MXSymDescribe(
c_str(self.name),
ctypes.byref(use_param)))
self.use_param = use_param.value
self.handle = handle
singleton_ = ctypes.c_void_p()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try to typedef SymbolHandle instead of c_void_p, see NArrayHandle type

check_call(_LIB.MXSymbolGetSingleton(self.handle, ctypes.byref(singleton_)))
self.singleton = Symbol(singleton_)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if use_param==true? consider handle such case


def __call__(self, **kwargs):
"""Invoke creator of symbol by passing kwargs

Parameters
----------
params : kwargs
params : **kwargs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

provide the params necessary for the symbol creation
Returns
-------
Expand All @@ -44,8 +43,8 @@ def __call__(self, **kwargs):
keys = c_array(ctypes.c_char_p, [c_str(key) for key in kwargs.keys()])
vals = c_array(ctypes.c_char_p, [c_str(str(val)) for val in kwargs.values()])
sym_handle = SymbolHandle()
check_call(_LIB.MXSymCreate(
c_str(self.name),
check_call(_LIB.MXSymbolCreateFromAtomicSymbol(
self.handle,
mx_uint(len(kwargs)),
keys,
vals,
Expand All @@ -55,12 +54,13 @@ def __call__(self, **kwargs):
class _SymbolCreatorRegistry(object):
"""Function Registry"""
def __init__(self):
plist = ctypes.POINTER(ctypes.c_char_p)()
plist = ctypes.POINTER(ctypes.c_void_p)()
size = ctypes.c_uint()
check_call(_LIB.MXListSyms(ctypes.byref(size),
ctypes.byref(plist)))
check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
hmap = {}
name = ctypes.c_char_p()
for i in range(size.value):
name = plist[i]
hmap[name.value] = _SymbolCreator(name.value)
name = _LIB.MXSymbolGetAtomicSymbolName(plist[i], ctypes.byref(name))
hmap[name.value] = _SymbolCreator(name.value, plist[i])
self.__dict__.update(hmap)
Loading