This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add name manager #118

Merged 2 commits on Sep 22, 2015
1 change: 1 addition & 0 deletions python/mxnet/__init__.py
@@ -12,6 +12,7 @@
from .base import MXNetError
from . import base
from . import ndarray
from . import name
from . import symbol
from . import kvstore as kv
from . import io
37 changes: 25 additions & 12 deletions python/mxnet/context.py
@@ -3,23 +3,36 @@
from __future__ import absolute_import

class Context(object):
"""Context representing device and device id in mxnet"""
"""Constructing a context.

Parameters
----------
device_type : {'cpu', 'gpu'} or Context.
String representing the device type

device_id : int (default=0)
The device id of the device, needed for GPU

Note
----
Context can also be used as a way to change the default context.

Examples
--------
Switch default context example:
>>> # array on cpu
>>> cpu_array = mx.ndarray.ones((2, 3))
>>> # switch the default context to GPU(2)
>>> with mx.Context(mx.gpu(2)):
...     gpu_array = mx.ndarray.ones((2, 3))
>>> gpu_array.context
Context(device_type=gpu, device_id=2)
"""
# static class variable
default_ctx = None
devtype2str = {1: 'cpu', 2: 'gpu'}
devstr2type = {'cpu': 1, 'gpu': 2}

def __init__(self, device_type, device_id=0):
"""Constructing a context.

Parameters
----------
device_type : str (can be 'cpu' or 'gpu')
a string representing the device type

device_id : int (default=0)
the device id of the device, needed for GPU
"""
if isinstance(device_type, Context):
self.device_typeid = device_type.device_typeid
self.device_id = device_type.device_id
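
The constructor accepts either a device string or an existing Context, as the isinstance check above shows. A minimal sketch of both forms (assuming a build where gpu(1) is a valid device):

import mxnet as mx

ctx = mx.Context('gpu', 1)    # explicit device type and id
copy = mx.Context(ctx)        # copy-construct from another Context
assert copy.device_id == 1    # typeid and id are carried over
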
31 changes: 17 additions & 14 deletions python/mxnet/initializer.py
@@ -1,4 +1,7 @@
# pylint: skip-file
# coding: utf-8
"""Initialization helper for mxnet"""
from __future__ import absolute_import

import numpy as np
from .base import string_types
from .ndarray import NDArray
@@ -36,17 +39,17 @@ def __call__(self, name, arr):
self._init_zero(name, arr)
else:
self._init_default(name, arr)

def _init_zero(self, name, arr):
# pylint: disable=no-self-use, missing-docstring
def _init_zero(self, _, arr):
arr[:] = 0.0

def _init_bias(self, name, arr):
def _init_bias(self, _, arr):
arr[:] = 0.0

def _init_gamma(self, name, arr):
def _init_gamma(self, _, arr):
arr[:] = 1.0

def _init_beta(self, name, arr):
def _init_beta(self, _, arr):
arr[:] = 0.0

def _init_weight(self, name, arr):
@@ -55,7 +58,7 @@ def _init_weight(self, name, arr):

def _init_default(self, name, _):
raise ValueError('Unknown initialization pattern for %s' % name)

# pylint: enable=no-self-use, missing-docstring

class Uniform(Initializer):
"""Initialize the weight with uniform [-scale, scale]
@@ -68,8 +71,8 @@ def __init__(self, scale=0.07):
def __init__(self, scale=0.07):
self.scale = scale

def _init_weight(self, name, arr):
random.uniform(-scale, scale, out=arr)
def _init_weight(self, _, arr):
random.uniform(-self.scale, self.scale, out=arr)


class Normal(Initializer):
@@ -81,10 +84,10 @@ class Normal(Initializer):
Standard deviation for gaussian distribution.
"""
def __init__(self, sigma=0.01):
super().__init__(sigma = sigma)
self.sigma = sigma

def _init_weight(self, name, arr):
random.normal(0, sigma, out=arr)
def _init_weight(self, _, arr):
random.normal(0, self.sigma, out=arr)


class Xavier(Initializer):
@@ -95,6 +98,6 @@ def _init_weight(self, _, arr):
# [in, out] for fullc
shape = arr.shape
fan_in, fan_out = shape[1], shape[0]
s = np.sqrt(6. / (fan_in + fan_out))
random.uniform(-s, s, out=arr)
scale = np.sqrt(6. / (fan_in + fan_out))
random.uniform(-scale, scale, out=arr)
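
As the dispatch in __call__ above suggests, an initializer instance is called with a parameter name and its array, and the rule is selected from the name. A minimal sketch (assuming the usual suffix-based dispatch, e.g. names ending in '_weight' route to _init_weight):

import mxnet as mx

init = mx.initializer.Uniform(scale=0.07)
weight = mx.ndarray.zeros((10, 5), mx.cpu())
init('fc1_weight', weight)   # '_weight' suffix -> _init_weight, uniform fill
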

1 change: 0 additions & 1 deletion python/mxnet/io.py
@@ -1,5 +1,4 @@
# coding: utf-8

"""NDArray interface of mxnet"""
from __future__ import absolute_import

1 change: 1 addition & 0 deletions python/mxnet/kvstore.py
@@ -2,6 +2,7 @@
# pylint: disable=invalid-name, global-statement
""" KVStore in mxnet """
from __future__ import absolute_import

import ctypes
from .ndarray import NDArray
from .base import _LIB
22 changes: 16 additions & 6 deletions python/mxnet/model.py
@@ -1,6 +1,8 @@
# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals
# pylint: disable=too-many-branches, too-many-statements, unused-argument
"""MXNet model module"""
from __future__ import absolute_import

import numpy as np
import time
import logging
@@ -201,11 +203,18 @@ def _train_multi_device(symbol, ctx, input_shape,
aux_params[name].copyto(w)
# key-value store
kv = kvstore.create() if num_device != 1 else None
opt_state_blocks = []
# If there are multiple devices, initialize the weights.
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
arg, grad = pair
if kv and grad[0] is not None:
kv.init(index, arg[0])
arg_list, grad_list = pair
if kv and grad_list[0] is not None:
kv.init(index, arg_list[0])
# attach state directly to each weight
opt_list = [optimizer.create_state(index, w) for w in arg_list]
opt_state_blocks.append(opt_list)
else:
opt_state_blocks.append(None)

# Input and output data structure
data_index, label_index = _check_arguments(symbol)
merged_shape = list(train_execs[0].outputs[0].shape)
@@ -244,9 +253,10 @@ def _train_multi_device(symbol, ctx, input_shape,
kv.push(index, grad_list)
# pull back the sum, to the same locations.
kv.pull(index, grad_list)
# optimize
for w, g in zip(arg_list, grad_list):
optimizer.update(index, w, g)
opt_list = opt_state_blocks[index]
# optimize
for w, g, state in zip(arg_list, grad_list, opt_list):
optimizer.update(index, w, g, state)
# evaluate at end, so out_cpu_array can lazy copy
eval_metric.update(out_cpu_array, label)

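
In outline, this change threads one optimizer-state list per weight index through the loop: each entry of arg_blocks holds the per-device copies of one weight, and create_state is called once per copy. A condensed sketch of the flow (not the full function; all variables come from the enclosing scope):

# Initialization: one kvstore entry and one state list per weight index.
for index, (arg_list, grad_list) in enumerate(zip(arg_blocks, grad_blocks)):
    if kv and grad_list[0] is not None:
        kv.init(index, arg_list[0])
        opt_state_blocks.append(
            [optimizer.create_state(index, w) for w in arg_list])
    else:
        opt_state_blocks.append(None)

# Per batch, after push/pull has summed gradients across devices:
for w, g, state in zip(arg_list, grad_list, opt_state_blocks[index]):
    optimizer.update(index, w, g, state)
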
78 changes: 78 additions & 0 deletions python/mxnet/name.py
@@ -0,0 +1,78 @@
# coding: utf-8
"""Automatic naming support for symbolic API."""
from __future__ import absolute_import

class NameManager(object):
"""NameManager to do automatic naming.

User can also inheritate this object to change naming behavior.
"""
current = None

def __init__(self):
self._counter = {}
self._old_manager = None

def get(self, name, hint):
"""Get the canonical name for a symbol.

This is the default implementation.
If the user specified a name, the user-specified name is used.

Otherwise, a name is generated automatically from the hint string.

Parameters
----------
name : str or None
The name user specified.

hint : str
A hint string, which can be used to generate name.

Returns
-------
full_name : str
A canonical name for the user.
"""
if name:
return name
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name

def __enter__(self):
self._old_manager = NameManager.current
NameManager.current = self
return self

def __exit__(self, ptype, value, trace):
assert self._old_manager
NameManager.current = self._old_manager


class Prefix(NameManager):
"""A name manager that always attach a prefix to all names.

Examples
--------
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> with mx.name.Prefix('mynet_'):
...     net = mx.symbol.FullyConnected(data, num_hidden=10, name='fc1')
>>> net.list_arguments()
['data', 'mynet_fc1_weight', 'mynet_fc1_bias']
"""
def __init__(self, prefix):
super(Prefix, self).__init__()
self._prefix = prefix

def get(self, name, hint):
name = super(Prefix, self).get(name, hint)
return self._prefix + name

# initialize the default name manager
NameManager.current = NameManager()
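
A quick sketch of the counter logic in get (hypothetical standalone use; in practice the hint comes from the symbol creators):

nm = NameManager()
assert nm.get(None, 'fullyconnected') == 'fullyconnected0'
assert nm.get(None, 'fullyconnected') == 'fullyconnected1'
assert nm.get('myfc', 'fullyconnected') == 'myfc'   # explicit name wins
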
36 changes: 26 additions & 10 deletions python/mxnet/optimizer.py
@@ -1,4 +1,4 @@
# pylint: disable=fixme, invalid-name
# pylint: disable=fixme, invalid-name, unused-argument
"""Common Optimization algorithms with regularizations."""
from .ndarray import NDArray, zeros

@@ -31,7 +31,6 @@ class SGD(Optimizer):

rescale_grad : float, optional
rescaling factor of gradient.

"""
def __init__(self, learning_rate=0.01, momentum=0.0,
wd=0.0001, rescale_grad=1):
@@ -41,7 +40,21 @@ def __init__(self, learning_rate=0.01, momentum=0.0,
self.rescale_grad = rescale_grad
self.momentums = {}

def update(self, index, weight, grad):
def create_state(self, index, weight):
"""Create additional optimizer state such as momentum.

Parameters
----------
index : int
The unique index of the weight.

weight : NDArray
The weight data.

"""
if self.momentum == 0.0:
return None
else:
return zeros(weight.shape, weight.context)

def update(self, index, weight, grad, state):
"""Update the parameters.

Parameters
@@ -55,17 +68,20 @@ def update(self, index, weight, grad):
grad : NDArray
grad ndarray

state : NDArray or other object returned by create_state
The auxiliary state used in optimization.
"""
# TODO(bing) implement wd_bias, wd_gamma, wd_beta
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))

if index not in self.momentums:
self.momentums[index] = zeros(grad.shape, grad.context)
mom = self.momentums[index]
mom[:] *= self.momentum
mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
weight[:] += mom
if state:
mom = state
mom[:] *= self.momentum
mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
weight[:] += mom
else:
assert self.momentum == 0.0
weight[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)


def create(name, rescale_grad=1, **kwargs):
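
The new contract separates state creation from the update step: the caller obtains the (possibly None) momentum buffer once per weight index and then passes it to every update. A minimal sketch (weight, grad, and num_steps are assumed to come from the caller):

opt = SGD(learning_rate=0.01, momentum=0.9)
state = opt.create_state(0, weight)      # momentum buffer, or None when momentum == 0
for _ in range(num_steps):
    opt.update(0, weight, grad, state)   # state is mutated in place
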
4 changes: 4 additions & 0 deletions python/mxnet/symbol.py
@@ -13,6 +13,7 @@
from .base import c_array, c_str, mx_uint, py_str, string_types
from .base import NDArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call, ctypes2docstring
from .name import NameManager
from .context import Context
from .ndarray import NDArray, zeros
from .executor import Executor
@@ -128,6 +129,7 @@ def _compose(self, *args, **kwargs):
the resulting symbol
"""
name = kwargs.pop('name', None)

if name:
name = c_str(name)
if len(args) != 0 and len(kwargs) != 0:
@@ -752,6 +754,8 @@ def creator(*args, **kwargs):
' instead of keyword arguments.')

s = Symbol(sym_handle)
hint = func_name.lower()
name = NameManager.current.get(name, hint)
s._compose(*args, name=name, **symbol_kwargs)
return s

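
With this hook, a symbol created without an explicit name is named automatically from the lower-cased operator name via NameManager.current. A brief sketch of the effect (assuming the standard symbol API):

import mxnet as mx

data = mx.symbol.Variable('data')
fc = mx.symbol.FullyConnected(data, num_hidden=10)   # no name supplied
# The creator consults NameManager.current.get(None, 'fullyconnected'),
# so list_arguments() shows e.g. 'fullyconnected0_weight', 'fullyconnected0_bias'.
print(fc.list_arguments())
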
2 changes: 2 additions & 0 deletions python/mxnet/visualization.py
@@ -2,6 +2,8 @@
# pylint: disable=invalid-name, protected-access, too-many-locals, fixme
# pylint: disable=unused-argument, too-many-branches, too-many-statements
"""Visualization module"""
from __future__ import absolute_import

from .symbol import Symbol
import json
import re
2 changes: 1 addition & 1 deletion tests/python/train/test_mlp.py
@@ -18,7 +18,7 @@

num_round = 4
prefix = './mlp'
model = mx.model.FeedForward(softmax, mx.cpu(),
model = mx.model.FeedForward(softmax, [mx.cpu()] * 2,
num_round=num_round,
learning_rate=0.01, wd=0.0004,
momentum=0.9)