Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

simplify symbol creator as discussed #22

Merged
merged 2 commits into from
Aug 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# MXNet

[![Build Status](https://travis-ci.org/dmlc/mxnet.svg?branch=master)](https://travis-ci.org/dmlc/mxnet)
[![Documentation Status](https://readthedocs.org/projects/mxnet/badge/?version=latest)](https://readthedocs.org/projects/mxnet/?badge=latest)

This is a project that combines lessons and ideas we learnt from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2).
- The interface is designed in collaboration by authors of three projects.
Expand Down
5 changes: 2 additions & 3 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from .context import Context, current_context
from .narray import NArray
from .function import _FunctionRegistry
from .symbol import Symbol
from .symbol_creator import _SymbolCreatorRegistry
from . import symbol

__version__ = "0.1.0"

# this is a global function registry that can be used to invoke functions
op = NArray._init_function_registry(_FunctionRegistry())
sym = Symbol._init_symbol_creator_registry(_SymbolCreatorRegistry())

130 changes: 113 additions & 17 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,16 @@
from __future__ import absolute_import

import ctypes
import sys
from .base import _LIB
from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle, SymbolHandle
from .base import c_array, c_str, mx_uint, string_types
from .base import NArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call
from .context import Context
from .executor import Executor

class Symbol(object):
"""Symbol is symbolic graph of the mxnet."""
_registry = None

@staticmethod
def _init_symbol_creator_registry(symbol_creator_registry):
"""Initialize symbol creator registry

Parameters
----------
symbol_creator_registry:
pass in symbol_creator_registry
Returns
-------
the passed in registry
"""
_registry = symbol_creator_registry
return _registry

def __init__(self, handle):
"""Initialize the function with handle
Expand Down Expand Up @@ -257,3 +243,113 @@ def bind(self, ctx, args, args_grad, reqs):
reqs_array,
ctypes.byref(handle)))
return Executor(handle)


def Variable(name):
"""Create a symbolic variable with specified name.

Parameters
----------
name : str
Name of the variable.

Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle)))
return Symbol(handle)


def Group(symbols):
"""Create a symbolic variable that groups several symbols together.

Parameters
----------
symbols : list
List of symbols to be grouped.

Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateGroup(
len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)


def _make_atomic_symbol_function(handle, func_name):
"""Create an atomic symbol function by handle and funciton name."""
def creator(*args, **kwargs):
"""Activation Operator of Neural Net.
The parameters listed below can be passed in as keyword arguments.

Parameters
----------
name : string, required.
Name of the resulting symbol.

Returns
-------
symbol: Symbol
the resulting symbol
"""
param_keys = []
param_vals = []
symbol_kwargs = {}
name = kwargs.pop('name', None)

for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
param_vals.append(c_str(str(v)))
# create atomic symbol
param_keys = c_array(ctypes.c_char_p, param_keys)
param_vals = c_array(ctypes.c_char_p, param_vals)
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
handle, len(param_keys),
param_keys, param_vals,
ctypes.byref(sym_handle)))

if len(args) != 0 and len(symbol_kwargs) != 0:
raise TypeError('%s can only accept input \
Symbols either as positional or keyword arguments, not both' % func_name)

s = Symbol(sym_handle)
s._compose(*args, name=name, **symbol_kwargs)
return s
creator.__name__ = func_name
return creator


def _init_module_functions():
"""List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
size = ctypes.c_uint()
check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules[__name__]
for i in range(size.value):
hdl = ctypes.c_void_p(plist[i])
name = ctypes.c_char_p()
check_call(_LIB.MXSymbolGetAtomicSymbolName(hdl, ctypes.byref(name)))
function = _make_atomic_symbol_function(hdl, name.value)
setattr(module_obj, function.__name__, function)

# Initialize the atomic symbo in startups
_init_module_functions()

132 changes: 0 additions & 132 deletions python/mxnet/symbol_creator.py

This file was deleted.

8 changes: 4 additions & 4 deletions python/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def Get(self):

# symbol net
batch_size = 100
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=160)
act1 = mx.sym.Activation(data = fc1, name='relu1', type="relu")
fc2 = mx.sym.FullyConnected(data = act1, name='fc2', num_hidden=10)
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=160)
act1 = mx.symbol.Activation(data = fc1, name='relu1', type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name='fc2', num_hidden=10)
args_list = fc2.list_arguments()
# infer shape
data_shape = (batch_size, 784)
Expand Down