This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

fix optimizer under multi device
tqchen committed Sep 22, 2015
1 parent 33e1cbb commit 10739e0
Showing 8 changed files with 64 additions and 32 deletions.
31 changes: 17 additions & 14 deletions python/mxnet/initializer.py
@@ -1,4 +1,7 @@
-# pylint: skip-file
+# coding: utf-8
+"""Initialization helper for mxnet"""
+from __future__ import absolute_import
+
 import numpy as np
 from .base import string_types
 from .ndarray import NDArray
@@ -36,17 +39,17 @@ def __call__(self, name, arr):
             self._init_zero(name, arr)
         else:
             self._init_default(name, arr)
-
-    def _init_zero(self, name, arr):
+    # pylint: disable=no-self-use, missing-docstring
+    def _init_zero(self, _, arr):
         arr[:] = 0.0
 
-    def _init_bias(self, name, arr):
+    def _init_bias(self, _, arr):
         arr[:] = 0.0
 
-    def _init_gamma(self, name, arr):
+    def _init_gamma(self, _, arr):
         arr[:] = 1.0
 
-    def _init_beta(self, name, arr):
+    def _init_beta(self, _, arr):
         arr[:] = 0.0
 
     def _init_weight(self, name, arr):
@@ -55,7 +58,7 @@ def _init_weight(self, name, arr):
 
     def _init_default(self, name, _):
         raise ValueError('Unknown initialization pattern for %s' % name)
-
+    # pylint: enable=no-self-use, missing-docstring
 
 class Uniform(Initializer):
     """Initialize the weight with uniform [-scale, scale]
@@ -68,8 +71,8 @@ class Uniform(Initializer):
     def __init__(self, scale=0.07):
         self.scale = scale
 
-    def _init_weight(self, name, arr):
-        random.uniform(-scale, scale, out=arr)
+    def _init_weight(self, _, arr):
+        random.uniform(-self.scale, self.scale, out=arr)
 
 
 class Normal(Initializer):
@@ -81,10 +84,10 @@ class Normal(Initializer):
         Standard deviation for gaussian distribution.
     """
     def __init__(self, sigma=0.01):
-        super().__init__(sigma = sigma)
+        self.sigma = sigma
 
-    def _init_weight(self, name, arr):
-        random.normal(0, sigma, out=arr)
+    def _init_weight(self, _, arr):
+        random.normal(0, self.sigma, out=arr)
 
 
 class Xavier(Initializer):
@@ -95,6 +98,6 @@ def _init_weight(self, _, arr):
         # [in, out] for fullc
         shape = arr.shape
         fan_in, fan_out = shape[1], shape[0]
-        s = np.sqrt(6. / (fan_in + fan_out))
-        random.uniform(-s, s, out=arr)
+        scale = np.sqrt(6. / (fan_in + fan_out))
+        random.uniform(-scale, scale, out=arr)

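A minimal usage sketch of the initializers above, assuming the usual name-suffix dispatch in Initializer.__call__ and the mx.nd.empty constructor; the layer names are illustrative, not taken from this diff:

    import mxnet as mx
    from mxnet.initializer import Uniform

    init = Uniform(scale=0.07)
    weight = mx.nd.empty((256, 128))  # [out, in] layout, as in the fullc comment
    init('fc1_weight', weight)        # assumed: names ending in 'weight' hit _init_weight
    bias = mx.nd.empty((256,))
    init('fc1_bias', bias)            # assumed: names ending in 'bias' are zero-filled
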
1 change: 0 additions & 1 deletion python/mxnet/io.py
@@ -1,5 +1,4 @@
# coding: utf-8

"""NDArray interface of mxnet"""
from __future__ import absolute_import

1 change: 1 addition & 0 deletions python/mxnet/kvstore.py
@@ -2,6 +2,7 @@
 # pylint: disable=invalid-name, global-statement
 """ KVStore in mxnet """
 from __future__ import absolute_import
+
 import ctypes
 from .ndarray import NDArray
 from .base import _LIB
22 changes: 16 additions & 6 deletions python/mxnet/model.py
@@ -1,6 +1,8 @@
 # pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals
 # pylint: disable=too-many-branches, too-many-statements, unused-argument
 """MXNet model module"""
+from __future__ import absolute_import
+
 import numpy as np
 import time
 import logging
@@ -201,11 +203,18 @@ def _train_multi_device(symbol, ctx, input_shape,
             aux_params[name].copyto(w)
     # key value store
     kv = kvstore.create() if num_device != 1 else None
+    opt_state_blocks = []
+    # If there are multiple devices, initialize the weights.
     for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
-        arg, grad = pair
-        if kv and grad[0] is not None:
-            kv.init(index, arg[0])
+        arg_list, grad_list = pair
+        if kv and grad_list[0] is not None:
+            kv.init(index, arg_list[0])
+            # attach state directly to weight
+            opt_list = [optimizer.create_state(index, w) for w in arg_list]
+            opt_state_blocks.append(opt_list)
+        else:
+            opt_state_blocks.append(None)
 
     # Input and output data structure
     data_index, label_index = _check_arguments(symbol)
     merged_shape = list(train_execs[0].outputs[0].shape)
@@ -244,9 +253,10 @@ def _train_multi_device(symbol, ctx, input_shape,
                 kv.push(index, grad_list)
                 # pull back the sum, to the same locations.
                 kv.pull(index, grad_list)
-            # optimize
-            for w, g in zip(arg_list, grad_list):
-                optimizer.update(index, w, g)
+            opt_list = opt_state_blocks[index]
+            # optimize
+            for w, g, state in zip(arg_list, grad_list, opt_list):
+                optimizer.update(index, w, g, state)
         # evaluate at end, so out_cpu_array can lazy copy
         eval_metric.update(out_cpu_array, label)

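The shape of the fix, end to end: one state list per parameter index, one entry per device copy, created beside the weight and handed back on every update, so momentum lives on the same device as the weight it tracks. A toy single-step sketch simulating two devices on CPU, built from the kvstore and optimizer calls visible in this diff (mx.nd.ones, the shapes, and the values are illustrative):

    import mxnet as mx
    from mxnet import kvstore
    from mxnet.optimizer import SGD

    opt = SGD(learning_rate=0.1, momentum=0.9)
    kv = kvstore.create()

    index = 0
    weights = [mx.nd.ones((4,)), mx.nd.ones((4,))]            # one copy per "device"
    grads = [mx.nd.ones((4,)) * 0.1, mx.nd.ones((4,)) * 0.3]  # per-device gradients
    states = [opt.create_state(index, w) for w in weights]    # momentum buffer per copy

    kv.init(index, weights[0])  # seed the store with one copy
    kv.push(index, grads)       # aggregate gradients across copies
    kv.pull(index, grads)       # pull the aggregate back into each copy
    for w, g, s in zip(weights, grads, states):
        opt.update(index, w, g, s)
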
1 change: 1 addition & 0 deletions python/mxnet/name.py
@@ -1,5 +1,6 @@
 # coding: utf-8
 """Automatic naming support for symbolic API."""
+from __future__ import absolute_import
 
 class NameManager(object):
     """NameManager to do automatic naming.
36 changes: 26 additions & 10 deletions python/mxnet/optimizer.py
@@ -1,4 +1,4 @@
-# pylint: disable=fixme, invalid-name
+# pylint: disable=fixme, invalid-name, unused-argument
 """Common Optimization algorithms with regularizations."""
 from .ndarray import NDArray, zeros
 
@@ -31,7 +31,6 @@ class SGD(Optimizer):
     rescale_grad : float, optional
         rescaling factor of gradient.
     """
-
     def __init__(self, learning_rate=0.01, momentum=0.0,
                  wd=0.0001, rescale_grad=1):
@@ -41,7 +40,21 @@ def __init__(self, learning_rate=0.01, momentum=0.0,
         self.rescale_grad = rescale_grad
         self.momentums = {}
 
-    def update(self, index, weight, grad):
+    def create_state(self, index, weight):
+        """Create additional optimizer state such as momentum.
+
+        Parameters
+        ----------
+        weight : NDArray
+            The weight data
+        """
+        if self.momentum == 0.0:
+            return None
+        else:
+            return zeros(weight.shape, weight.context)
+
+    def update(self, index, weight, grad, state):
         """Update the parameters.
 
         Parameters
@@ -55,17 +68,20 @@ def update(self, index, weight, grad):
         grad : NDArray
             grad ndarray
 
+        state : NDArray or other objects returned by create_state
+            The auxiliary state used in optimization.
         """
         # TODO(bing) implement wd_bias, wd_gamma, wd_beta
         assert(isinstance(weight, NDArray))
         assert(isinstance(grad, NDArray))
 
-        if index not in self.momentums:
-            self.momentums[index] = zeros(grad.shape, grad.context)
-        mom = self.momentums[index]
-        mom[:] *= self.momentum
-        mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
-        weight[:] += mom
+        if state:
+            mom = state
+            mom[:] *= self.momentum
+            mom[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
+            weight[:] += mom
+        else:
+            assert self.momentum == 0.0
+            weight[:] += -self.lr * (grad * self.rescale_grad + self.wd * weight)
 
 
 def create(name, rescale_grad=1, **kwargs):
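For reference, the branch above is plain momentum SGD with weight decay folded into the gradient; with momentum == 0.0, create_state returns None and the buffer-free branch runs. The same arithmetic in a standalone NumPy sketch (illustrative values, not library code):

    import numpy as np

    lr, momentum, wd, rescale_grad = 0.1, 0.9, 0.0001, 1.0
    weight = np.ones(3)
    grad = np.full(3, 0.5)
    mom = np.zeros(3)  # what create_state returns when momentum != 0

    mom = momentum * mom - lr * (grad * rescale_grad + wd * weight)
    weight += mom  # momentum branch
    # momentum == 0.0 branch, no state allocated:
    # weight += -lr * (grad * rescale_grad + wd * weight)
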
2 changes: 2 additions & 0 deletions python/mxnet/visualization.py
@@ -2,6 +2,8 @@
 # pylint: disable=invalid-name, protected-access, too-many-locals, fixme
 # pylint: disable=unused-argument, too-many-branches, too-many-statements
 """Visualization module"""
+from __future__ import absolute_import
+
 from .symbol import Symbol
 import json
 import re
2 changes: 1 addition & 1 deletion tests/python/train/test_mlp.py
@@ -18,7 +18,7 @@
 
 num_round = 4
 prefix = './mlp'
-model = mx.model.FeedForward(softmax, mx.cpu(),
+model = mx.model.FeedForward(softmax, [mx.cpu()] * 2,
                              num_round=num_round,
                              learning_rate=0.01, wd=0.0004,
                              momentum=0.9)
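
Running the MLP on [mx.cpu()] * 2 makes num_device != 1, so kvstore.create() is actually invoked and the new per-device create_state/update path is exercised by the test.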
