From 4e0e413e34fa56757feebcd56a1babb982c8ff50 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Thu, 20 Aug 2015 12:15:18 -0600 Subject: [PATCH 1/2] Read docs --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index b1ef53bc6148..4659ce9f0413 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # MXNet [![Build Status](https://travis-ci.org/dmlc/mxnet.svg?branch=master)](https://travis-ci.org/dmlc/mxnet) +[![Documentation Status](https://readthedocs.org/projects/mxnet/badge/?version=latest)](https://readthedocs.org/projects/mxnet/?badge=latest) This is a project that combines lessons and ideas we learnt from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2). - The interface is designed in collaboration by authors of three projects. From ef7bb0674182792195e79b35d507ad2a9ff2136f Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Thu, 20 Aug 2015 23:17:39 -0600 Subject: [PATCH 2/2] simplify symbol creator as discussed --- python/mxnet/__init__.py | 5 +- python/mxnet/symbol.py | 130 +++++++++++++++++++++++++++----- python/mxnet/symbol_creator.py | 132 --------------------------------- python/test_mnist.py | 8 +- 4 files changed, 119 insertions(+), 156 deletions(-) delete mode 100644 python/mxnet/symbol_creator.py diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 2a70190fd3cd..94b71bce16cc 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -12,11 +12,10 @@ from .context import Context, current_context from .narray import NArray from .function import _FunctionRegistry -from .symbol import Symbol -from .symbol_creator import _SymbolCreatorRegistry +from . import symbol __version__ = "0.1.0" # this is a global function registry that can be used to invoke functions op = NArray._init_function_registry(_FunctionRegistry()) -sym = Symbol._init_symbol_creator_registry(_SymbolCreatorRegistry()) + diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index b4f8cd1b7914..cb5daf2225a1 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -4,30 +4,16 @@ from __future__ import absolute_import import ctypes +import sys from .base import _LIB -from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle, SymbolHandle +from .base import c_array, c_str, mx_uint, string_types +from .base import NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context from .executor import Executor class Symbol(object): """Symbol is symbolic graph of the mxnet.""" - _registry = None - - @staticmethod - def _init_symbol_creator_registry(symbol_creator_registry): - """Initialize symbol creator registry - - Parameters - ---------- - symbol_creator_registry: - pass in symbol_creator_registry - Returns - ------- - the passed in registry - """ - _registry = symbol_creator_registry - return _registry def __init__(self, handle): """Initialize the function with handle @@ -257,3 +243,113 @@ def bind(self, ctx, args, args_grad, reqs): reqs_array, ctypes.byref(handle))) return Executor(handle) + + +def Variable(name): + """Create a symbolic variable with specified name. + + Parameters + ---------- + name : str + Name of the variable. + + Returns + ------- + variable : Symbol + The created variable symbol. + """ + if not isinstance(name, string_types): + raise TypeError('Expect a string for variable `name`') + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle))) + return Symbol(handle) + + +def Group(symbols): + """Create a symbolic variable that groups several symbols together. + + Parameters + ---------- + symbols : list + List of symbols to be grouped. + + Returns + ------- + sym : Symbol + The created group symbol. + """ + ihandles = [] + for sym in symbols: + if not isinstance(sym, Symbol): + raise TypeError('Expect Symbols in the list input') + ihandles.append(sym.handle) + handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateGroup( + len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle))) + return Symbol(handle) + + +def _make_atomic_symbol_function(handle, func_name): + """Create an atomic symbol function by handle and funciton name.""" + def creator(*args, **kwargs): + """Activation Operator of Neural Net. + The parameters listed below can be passed in as keyword arguments. + + Parameters + ---------- + name : string, required. + Name of the resulting symbol. + + Returns + ------- + symbol: Symbol + the resulting symbol + """ + param_keys = [] + param_vals = [] + symbol_kwargs = {} + name = kwargs.pop('name', None) + + for k, v in kwargs.items(): + if isinstance(v, Symbol): + symbol_kwargs[k] = v + else: + param_keys.append(c_str(k)) + param_vals.append(c_str(str(v))) + # create atomic symbol + param_keys = c_array(ctypes.c_char_p, param_keys) + param_vals = c_array(ctypes.c_char_p, param_vals) + sym_handle = SymbolHandle() + check_call(_LIB.MXSymbolCreateAtomicSymbol( + handle, len(param_keys), + param_keys, param_vals, + ctypes.byref(sym_handle))) + + if len(args) != 0 and len(symbol_kwargs) != 0: + raise TypeError('%s can only accept input \ + Symbols either as positional or keyword arguments, not both' % func_name) + + s = Symbol(sym_handle) + s._compose(*args, name=name, **symbol_kwargs) + return s + creator.__name__ = func_name + return creator + + +def _init_module_functions(): + """List and add all the atomic symbol functions to current module.""" + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size), + ctypes.byref(plist))) + module_obj = sys.modules[__name__] + for i in range(size.value): + hdl = ctypes.c_void_p(plist[i]) + name = ctypes.c_char_p() + check_call(_LIB.MXSymbolGetAtomicSymbolName(hdl, ctypes.byref(name))) + function = _make_atomic_symbol_function(hdl, name.value) + setattr(module_obj, function.__name__, function) + +# Initialize the atomic symbo in startups +_init_module_functions() + diff --git a/python/mxnet/symbol_creator.py b/python/mxnet/symbol_creator.py deleted file mode 100644 index bcadbe7daacb..000000000000 --- a/python/mxnet/symbol_creator.py +++ /dev/null @@ -1,132 +0,0 @@ -# coding: utf-8 -# pylint: disable=invalid-name, protected-access, no-self-use -"""Symbol support of mxnet""" -from __future__ import absolute_import - -import ctypes -from .base import _LIB -from .base import c_array, c_str, string_types -from .base import SymbolHandle -from .base import check_call -from .symbol import Symbol - -class _SymbolCreator(object): - """SymbolCreator is a function that takes Param and return symbol""" - - def __init__(self, name, handle): - """Initialize the function with handle - - Parameters - ---------- - handle : SymbolCreatorHandle - the function handle of the function - - name : string - the name of the function - """ - self.name = name - self.handle = handle - - def __call__(self, *args, **kwargs): - """Invoke creator of symbol by passing kwargs - - Parameters - ---------- - name : string - Name of the resulting symbol. - - *args - Positional arguments - - **kwargs - Provide the params necessary for the symbol creation. - - Returns - ------- - the resulting symbol - """ - param_keys = [] - param_vals = [] - symbol_kwargs = {} - name = kwargs.pop('name', None) - - for k, v in kwargs.items(): - if isinstance(v, Symbol): - symbol_kwargs[k] = v - else: - param_keys.append(c_str(k)) - param_vals.append(c_str(str(v))) - - # create atomic symbol - param_keys = c_array(ctypes.c_char_p, param_keys) - param_vals = c_array(ctypes.c_char_p, param_vals) - sym_handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateAtomicSymbol( - self.handle, len(param_keys), - param_keys, param_vals, - ctypes.byref(sym_handle))) - - if len(args) != 0 and len(symbol_kwargs) != 0: - raise TypeError('%s can only accept input \ - Symbols either as positional or keyword arguments, not both' % self.name) - - s = Symbol(sym_handle) - s._compose(*args, name=name, **symbol_kwargs) - return s - -class _SymbolCreatorRegistry(object): - """Function Registry""" - def __init__(self): - plist = ctypes.POINTER(ctypes.c_void_p)() - size = ctypes.c_uint() - check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size), - ctypes.byref(plist))) - hmap = {} - for i in range(size.value): - hdl = ctypes.c_void_p(plist[i]) - name = ctypes.c_char_p() - check_call(_LIB.MXSymbolGetAtomicSymbolName(hdl, ctypes.byref(name))) - hmap[name.value] = _SymbolCreator(name, hdl) - self.__dict__.update(hmap) - - def Variable(self, name): - """Create a symbolic variable with specified name. - - Parameters - ---------- - name : str - Name of the variable. - - Returns - ------- - variable : Symbol - The created variable symbol. - """ - if not isinstance(name, string_types): - raise TypeError('Expect a string for variable `name`') - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle))) - return Symbol(handle) - - def Group(self, symbols): - """Create a symbolic variable that groups several symbols together. - - Parameters - ---------- - symbols : list - List of symbols to be grouped. - - Returns - ------- - sym : Symbol - The created group symbol. - """ - ihandles = [] - for sym in symbols: - if not isinstance(sym, Symbol): - raise TypeError('Expect Symbols in the list input') - ihandles.append(sym.handle) - handle = SymbolHandle() - check_call(_LIB.MXSymbolCreateGroup( - len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle))) - return Symbol(handle) diff --git a/python/test_mnist.py b/python/test_mnist.py index f9f37d2e82e3..9b61654f8897 100644 --- a/python/test_mnist.py +++ b/python/test_mnist.py @@ -63,10 +63,10 @@ def Get(self): # symbol net batch_size = 100 -data = mx.sym.Variable('data') -fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=160) -act1 = mx.sym.Activation(data = fc1, name='relu1', type="relu") -fc2 = mx.sym.FullyConnected(data = act1, name='fc2', num_hidden=10) +data = mx.symbol.Variable('data') +fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=160) +act1 = mx.symbol.Activation(data = fc1, name='relu1', type="relu") +fc2 = mx.symbol.FullyConnected(data = act1, name='fc2', num_hidden=10) args_list = fc2.list_arguments() # infer shape data_shape = (batch_size, 784)