Set idx2name for Optimizer object (apache#14703)
* set idx2name for optimizer

* add unit test
yuxihu authored and haohuw committed Jun 23, 2019
1 parent 9076741 commit 35f5364
Showing 3 changed files with 39 additions and 7 deletions.
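
Background for the change: MXNet optimizers keep an idx2name dict mapping each parameter's integer index to its name, which name-based settings such as set_lr_mult/set_wd_mult rely on. Before this commit, an Optimizer object passed directly to a Module (rather than created from a string name) never received that mapping. A minimal sketch of the fixed behavior, assuming a single-device CPU module with one FullyConnected layer (shapes and names are illustrative):

import mxnet as mx

data = mx.sym.Variable('data')
sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')

mod = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
mod.bind(data_shapes=[('data', (8, 20))])
mod.init_params()

# Pass an Optimizer object instead of the string 'sgd'.
opt = mx.optimizer.SGD(learning_rate=0.1, rescale_grad=1.0 / 8)
mod.init_optimizer(optimizer=opt)

# With this commit, the module copies its index->name mapping into the
# optimizer; previously opt.idx2name stayed empty on this code path.
print(mod._optimizer.idx2name)  # e.g. {0: 'fc_weight', 1: 'fc_bias'}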
python/mxnet/model.py (2 additions, 0 deletions)
@@ -884,6 +884,8 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
                                        rescale_grad=(1.0/batch_size),
                                        **(self.kwargs))
         elif isinstance(self.optimizer, opt.Optimizer):
+            if not self.optimizer.idx2name:
+                self.optimizer.idx2name = param_idx2name.copy()
             optimizer = self.optimizer

         # do training
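The model.py hunk covers the legacy FeedForward path, where fit() builds param_idx2name itself. A hedged sketch of how the fix surfaces there; the constructor arguments follow the deprecated mx.model.FeedForward API and the toy data shapes are assumptions, not values from the diff:

import numpy as np
import mxnet as mx

data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=2, name='fc')
sym = mx.sym.SoftmaxOutput(net, name='softmax')

# Pass an Optimizer object to FeedForward rather than a string.
opt = mx.optimizer.SGD(learning_rate=0.1)
model = mx.model.FeedForward(symbol=sym, ctx=mx.cpu(0),
                             num_epoch=1, optimizer=opt)

# fit() now copies its param_idx2name mapping into opt.idx2name
# before training starts.
model.fit(X=np.random.rand(8, 4).astype('float32'),
          y=np.array([0, 1] * 4, dtype='float32'))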
python/mxnet/module/module.py (9 additions, 7 deletions)
@@ -505,14 +505,14 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                 batch_size *= kvstore.num_workers
             rescale_grad = 1.0/batch_size

+        idx2name = {}
+        if update_on_kvstore:
+            idx2name.update(enumerate(self._exec_group.param_names))
+        else:
+            for k in range(len(self._context)):
+                idx2name.update({i*len(self._context)+k: n
+                                 for i, n in enumerate(self._exec_group.param_names)})
         if isinstance(optimizer, str):
-            idx2name = {}
-            if update_on_kvstore:
-                idx2name.update(enumerate(self._exec_group.param_names))
-            else:
-                for k in range(len(self._context)):
-                    idx2name.update({i*len(self._context)+k: n
-                                     for i, n in enumerate(self._exec_group.param_names)})
             optimizer_params = dict(optimizer_params)
             if 'rescale_grad' not in optimizer_params:
                 optimizer_params['rescale_grad'] = rescale_grad
@@ -528,6 +528,8 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                     "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "%(
                     optimizer.rescale_grad, rescale_grad) +
                     "Is this intended?", stacklevel=2)
+            if not optimizer.idx2name:
+                optimizer.idx2name = idx2name.copy()

         self._optimizer = optimizer
         self._kvstore = kvstore
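The moved block computes the index-to-name mapping before the isinstance branch so that both the string path and the Optimizer-object path can use it. A standalone sketch of the two layouts it produces, with illustrative parameter names and a device count of two (both assumptions, not values from the diff):

param_names = ['fc_weight', 'fc_bias']  # illustrative
num_devices = 2                         # illustrative

# update_on_kvstore=True: one optimizer slot per parameter.
idx2name_kv = dict(enumerate(param_names))
# -> {0: 'fc_weight', 1: 'fc_bias'}

# update_on_kvstore=False: each device updates its own copy, so the
# flat indices interleave parameters across devices.
idx2name_local = {}
for k in range(num_devices):
    idx2name_local.update({i * num_devices + k: n
                           for i, n in enumerate(param_names)})
# -> {0: 'fc_weight', 1: 'fc_weight', 2: 'fc_bias', 3: 'fc_bias'}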
tests/python/unittest/test_module.py (28 additions, 0 deletions)
@@ -931,6 +931,34 @@ def test_module_update_no_pragram():
     mod.update()
     assert(mod.get_outputs()[0].shape == data_shape)

+
+def test_module_init_optimizer():
+    def get_module_idx2name(mod):
+        idx2name = {}
+        idx2name.update(enumerate(mod._exec_group.param_names))
+        return idx2name
+
+    data = mx.sym.Variable('data')
+    sym = mx.sym.FullyConnected(data, num_hidden=20, name='fc')
+    batch_size = 8
+    opt_params = {'learning_rate': 1, 'rescale_grad': 1.0 / batch_size}
+
+    # Pass an optimizer str
+    mod1 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod1.bind(data_shapes=[('data', (batch_size, 20))])
+    mod1.init_params()
+    mod1.init_optimizer(optimizer='sgd', optimizer_params=opt_params)
+    assert mod1._optimizer.idx2name == get_module_idx2name(mod1)
+
+    # Pass an Optimizer object
+    mod2 = mx.mod.Module(sym, ('data',), None, context=mx.cpu(0))
+    mod2.bind(data_shapes=[('data', (batch_size, 20))])
+    mod2.init_params()
+    opt = mx.optimizer.SGD(**opt_params)
+    mod2.init_optimizer(optimizer=opt)
+    assert mod2._optimizer.idx2name == get_module_idx2name(mod2)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
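
The new test covers both init_optimizer code paths: the string path, where the optimizer is created internally, and the object path fixed by this commit. To run just this test with the repository's nose-based runner (command assumes a source checkout with the tests directory present):

nosetests tests/python/unittest/test_module.py:test_module_init_optimizer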
