Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #118 from tqchen/master
Browse files Browse the repository at this point in the history
add name manager
  • Loading branch information
mli committed Sep 22, 2015
2 parents 99d5925 + 10739e0 commit 02ca41b
Show file tree
Hide file tree
Showing 11 changed files with 171 additions and 44 deletions.
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .base import MXNetError
from . import base
from . import ndarray
from . import name
from . import symbol
from . import kvstore as kv
from . import io
Expand Down
37 changes: 25 additions & 12 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,36 @@
from __future__ import absolute_import

class Context(object):
"""Context representing device and device id in mxnet"""
"""Constructing a context.
Parameters
----------
device_type : {'cpu', 'gpu'} or Context.
String representing the device type
device_id : int (default=0)
The device id of the device, needed for GPU
Note
----
Context can also be used as a way to change the default context.
Examples
--------
Switch default context example:
>>> # array on cpu
>>> cpu_array = mx.nd.ones((2, 3))
>>> # switch default context to GPU(2)
>>> with mx.Context(mx.gpu(2)):
...     gpu_array = mx.nd.ones((2, 3))
>>> gpu_array.context
Context(device_type=gpu, device_id=2)
"""
# static class variable
default_ctx = None
devtype2str = {1: 'cpu', 2: 'gpu'}
devstr2type = {'cpu': 1, 'gpu': 2}

def __init__(self, device_type, device_id=0):
"""Constructing a context.
Parameters
----------
device_type : str (can be 'cpu' or 'gpu')
a string representing the device type
device_id : int (default=0)
the device id of the device, needed for GPU
"""
if isinstance(device_type, Context):
self.device_typeid = device_type.device_typeid
self.device_id = device_type.device_id
Expand Down
31 changes: 17 additions & 14 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# pylint: skip-file
# coding: utf-8
"""Initialization helper for mxnet"""
from __future__ import absolute_import

import numpy as np
from .base import string_types
from .ndarray import NDArray
Expand Down Expand Up @@ -36,17 +39,17 @@ def __call__(self, name, arr):
self._init_zero(name, arr)
else:
self._init_default(name, arr)

def _init_zero(self, name, arr):
# pylint: disable=no-self-use, missing-docstring
def _init_zero(self, _, arr):
arr[:] = 0.0

def _init_bias(self, name, arr):
def _init_bias(self, _, arr):
arr[:] = 0.0

def _init_gamma(self, name, arr):
def _init_gamma(self, _, arr):
arr[:] = 1.0

def _init_beta(self, name, arr):
def _init_beta(self, _, arr):
arr[:] = 0.0

def _init_weight(self, name, arr):
Expand All @@ -55,7 +58,7 @@ def _init_weight(self, name, arr):

def _init_default(self, name, _):
raise ValueError('Unknown initialization pattern for %s' % name)

# pylint: enable=no-self-use, missing-docstring

class Uniform(Initializer):
"""Initialize the weight with uniform [-scale, scale]
Expand All @@ -68,8 +71,8 @@ class Uniform(Initializer):
def __init__(self, scale=0.07):
self.scale = scale

def _init_weight(self, name, arr):
random.uniform(-scale, scale, out=arr)
def _init_weight(self, _, arr):
random.uniform(-self.scale, self.scale, out=arr)


class Normal(Initializer):
Expand All @@ -81,10 +84,10 @@ class Normal(Initializer):
Standard deviation for gaussian distribution.
"""
def __init__(self, sigma=0.01):
super().__init__(sigma = sigma)
self.sigma = sigma

def _init_weight(self, name, arr):
random.normal(0, sigma, out=arr)
def _init_weight(self, _, arr):
random.normal(0, self.sigma, out=arr)


class Xavier(Initializer):
Expand All @@ -95,6 +98,6 @@ def _init_weight(self, _, arr):
# [in, out] for fullc
shape = arr.shape
fan_in, fan_out = shape[1], shape[0]
s = np.sqrt(6. / (fan_in + fan_out))
random.uniform(-s, s, out=arr)
scale = np.sqrt(6. / (fan_in + fan_out))
random.uniform(-scale, scale, out=arr)

1 change: 0 additions & 1 deletion python/mxnet/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# coding: utf-8

"""NDArray interface of mxnet"""
from __future__ import absolute_import

Expand Down
1 change: 1 addition & 0 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# pylint: disable=invalid-name, global-statement
""" KVStore in mxnet """
from __future__ import absolute_import

import ctypes
from .ndarray import NDArray
from .base import _LIB
Expand Down
22 changes: 16 additions & 6 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals
# pylint: disable=too-many-branches, too-many-statements, unused-argument
"""MXNet model module"""
from __future__ import absolute_import

import numpy as np
import time
import logging
Expand Down Expand Up @@ -201,11 +203,18 @@ def _train_multi_device(symbol, ctx, input_shape,
aux_params[name].copyto(w)
# ky value store
kv = kvstore.create() if num_device != 1 else None
opt_state_blocks = []
# If there are multiple devices, initialize the weights.
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg, grad = pair
if kv and grad[0] is not None:
kv.init(index, arg[0])
arg_list, grad_list = pair
if kv and grad_list[0] is not None:
kv.init(index, arg_list[0])
# attach state direct to weight
opt_list = [optimizer.create_state(index, w) for w in arg_list]
opt_state_blocks.append(opt_list)
else:
opt_state_blocks.append(None)

# Input and output data structure
data_index, label_index = _check_arguments(symbol)
merged_shape = list(train_execs[0].outputs[0].shape)
Expand Down Expand Up @@ -244,9 +253,10 @@ def _train_multi_device(symbol, ctx, input_shape,
kv.push(index, grad_list)
# pull back the sum, to the same locations.
kv.pull(index, grad_list)
# optimize
for w, g in zip(arg_list, grad_list):
optimizer.update(index, w, g)
opt_list = opt_state_blocks[index]
# optimize
for w, g, state in zip(arg_list, grad_list, opt_list):
optimizer.update(index, w, g, state)
# evaluate at end, so out_cpu_array can lazy copy
eval_metric.update(out_cpu_array, label)

Expand Down
78 changes: 78 additions & 0 deletions python/mxnet/name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import

class NameManager(object):
    """Name manager that performs automatic naming of symbols.

    An instance can be used in a ``with`` statement: inside the block it
    becomes ``NameManager.current`` and on exit the previously active
    manager is restored.  The same instance may be entered again while it
    is already active; each exit undoes exactly one enter.

    Users can also inherit from this class and override ``get`` to change
    the naming behavior.
    """
    # The manager that is currently active; swapped by __enter__/__exit__.
    current = None

    def __init__(self):
        # Next numeric suffix to hand out, keyed by hint string.
        self._counter = {}
        # Manager that was active before the most recent __enter__.
        self._old_manager = None
        # One saved manager per nested __enter__.  A single slot is not
        # enough: re-entering the same instance would overwrite it and the
        # manager active before the outermost ``with`` would be lost.
        self._old_managers = []

    def get(self, name, hint):
        """Get the canonical name for a symbol.

        This is the default implementation.  When the user specified a
        name, that name is used unchanged.  Otherwise a name is generated
        from the hint string followed by a per-hint running counter
        (``hint0``, ``hint1``, ...).

        Parameters
        ----------
        name : str or None
            The name the user specified.  Any falsy value (``None`` or an
            empty string) counts as "not specified".
        hint : str
            A hint string, which is used to generate the name.

        Returns
        -------
        full_name : str
            A canonical name for the symbol.
        """
        if name:
            return name
        if hint not in self._counter:
            self._counter[hint] = 0
        name = '%s%d' % (hint, self._counter[hint])
        self._counter[hint] += 1
        return name

    def __enter__(self):
        """Make this manager the current one, remembering the previous one."""
        self._old_manager = NameManager.current
        self._old_managers.append(self._old_manager)
        NameManager.current = self
        return self

    def __exit__(self, ptype, value, trace):
        """Restore the manager that was current at the matching __enter__."""
        # Internal invariant: there must be a saved, non-empty manager to
        # go back to (same check the single-slot version performed).
        assert self._old_managers and self._old_managers[-1]
        NameManager.current = self._old_managers.pop()
        # Keep the single-slot attribute in step with the stack.
        if self._old_managers:
            self._old_manager = self._old_managers[-1]
        else:
            self._old_manager = None


class Prefix(NameManager):
    """A name manager that always attaches a prefix to all names.

    The prefix is put in front of whatever name the base ``NameManager``
    returns, so it applies to user-specified names as well as to
    automatically generated ones.

    Parameters
    ----------
    prefix : str
        The string that is prepended to every name.

    Examples
    --------
    >>> import mxnet as mx
    >>> data = mx.symbol.Variable('data')
    >>> with mx.name.Prefix('mynet_'):
    ...     net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1')
    >>> net.list_arguments()
    ['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
    """
    def __init__(self, prefix):
        super(Prefix, self).__init__()
        self._prefix = prefix

    def get(self, name, hint):
        """Return the base manager's name for the symbol with the prefix added."""
        base_name = super(Prefix, self).get(name, hint)
        prefixed_name = self._prefix + base_name
        return prefixed_name

# initialize the default name manager
NameManager.current = NameManager()
36 changes: 26 additions & 10 deletions python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# pylint: disable=fixme, invalid-name
# pylint: disable=fixme, invalid-name, unused-argument
"""Common Optimization algorithms with regularizations."""
from .ndarray import NDArray, zeros

Expand Down Expand Up @@ -31,7 +31,6 @@ class SGD(Optimizer):
rescale_grad : float, optional
rescaling factor of gradient.
"""
def __init__(self, learning_rate=0.01, momentum=0.0,
wd=0.0001, rescale_grad=1):
Expand All @@ -41,7 +40,21 @@ def __init__(self, learning_rate=0.01, momentum=0.0,
self.rescale_grad = rescale_grad
self.momentums = {}

def update(self, index, weight, grad):
def create_state(self, index, weight):
    """Create the additional optimizer state (the momentum buffer) for a weight.

    Parameters
    ----------
    index : int
        Index identifying the weight.  It is not used by this method.
    weight : NDArray
        The weight data.

    Returns
    -------
    state : NDArray or None
        ``None`` when ``self.momentum`` is 0.0 (no momentum buffer is
        needed); otherwise the array produced by ``zeros`` for the shape
        and context of ``weight``.
    """
    # Without momentum there is nothing to keep between updates.
    if self.momentum == 0.0:
        return None
    return zeros(weight.shape, weight.context)

def update(self, index, weight, grad, state):
"""Update the parameters.
Parameters
Expand All @@ -55,17 +68,20 @@ def update(self, index, weight, grad):
grad : NDArray
grad ndarray
state : NDArray or other objects returned by create_state
The auxiliary state used in optimization.
"""
# TODO(bing) implement wd_bias, wd_gamma, wd_beta
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))

if index not in self.momentums:
self.momentums[index] = zeros(grad.shape, grad.context)
mom = self.momentums[index]
mom[:] *= self.momentum
mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
weight[:] += mom
if state:
mom = state
mom[:] *= self.momentum
mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
weight[:] += mom
else:
assert self.momentum == 0.0
weight[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)


def create(name, rescale_grad=1, **kwargs):
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .base import c_array, c_str, mx_uint, py_str, string_types
from .base import NDArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call, ctypes2docstring
from .name import NameManager
from .context import Context
from .ndarray import NDArray, zeros
from .executor import Executor
Expand Down Expand Up @@ -128,6 +129,7 @@ def _compose(self, *args, **kwargs):
the resulting symbol
"""
name = kwargs.pop('name', None)

if name:
name = c_str(name)
if len(args) != 0 and len(kwargs) != 0:
Expand Down Expand Up @@ -752,6 +754,8 @@ def creator(*args, **kwargs):
' instead of keyword arguments.')

s = Symbol(sym_handle)
hint = func_name.lower()
name = NameManager.current.get(name, hint)
s._compose(*args, name=name, **symbol_kwargs)
return s

Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
# pylint: disable=unused-argument, too-many-branches, too-many-statements
"""Visualization module"""
from __future__ import absolute_import

from .symbol import Symbol
import json
import re
Expand Down
2 changes: 1 addition & 1 deletion tests/python/train/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

num_round = 4
prefix = './mlp'
model = mx.model.FeedForward(softmax, mx.cpu(),
model = mx.model.FeedForward(softmax, [mx.cpu()] * 2,
num_round=num_round,
learning_rate=0.01, wd=0.0004,
momentum=0.9)
Expand Down

0 comments on commit 02ca41b

Please sign in to comment.