Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Correct update count with Gluon trainer and update_on_kvstore=False #14377

Merged
merged 4 commits from the source branch into the base branch
Mar 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions python/mxnet/gluon/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,6 @@ def _init_kvstore(self):
kvstore.set_optimizer(self._optimizer)
self._kvstore = kvstore
self._update_on_kvstore = update_on_kvstore
if self._optimizer.lr_scheduler and not self._update_on_kvstore:
raise ValueError("update_on_kvstore=False does not support " \
"optimizer with LRScheduler. Please " \
"consider setting learning rate manually.")
else:
self._kvstore = None
self._update_on_kvstore = None
Expand Down
17 changes: 16 additions & 1 deletion python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
self.wd_mult = {}
self.begin_num_update = begin_num_update
self.num_update = begin_num_update
self._index_update_count = {}
self._all_index_update_counts = {0 : {}}
self._index_update_count = self._all_index_update_counts[0]
self.clip_gradient = clip_gradient
self.multi_precision = multi_precision
self.aggregate_num = 0
Expand Down Expand Up @@ -380,6 +381,18 @@ def set_wd_mult(self, args_wd_mult):
self.wd_mult[name] = float(attr[name]['__wd_mult__'])
self.wd_mult.update(args_wd_mult)

def _set_current_context(self, device_id):
"""Sets the number of the currently handled device.

Parameters
----------
device_id : int
The number of current device.
"""
if device_id not in self._all_index_update_counts:
self._all_index_update_counts[device_id] = {}
self._index_update_count = self._all_index_update_counts[device_id]

def _update_count(self, index):
"""Updates num_update.

Expand Down Expand Up @@ -1623,6 +1636,8 @@ def __call__(self, index, grad, weight):
indices = index
grads = grad
weights = weight
if weights:
self.optimizer._set_current_context(weights[0].context.device_id)
for i, idx in enumerate(indices):
# convert ctypes.char_p.value back to python str if needed
if isinstance(idx, bytes):
Expand Down
21 changes: 12 additions & 9 deletions tests/python/unittest/test_gluon_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,19 +272,22 @@ def test_trainer_lr_sched():
lr *= factor
mx.nd.waitall()

@with_seed()
def test_trainer_invalid_lr_sched():
# Update on kvstore = False
x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
freq = 2
factor = 0.1
lr = 1
lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
update_on_kvstore=False)
with mx.autograd.record():
for w in x.list_data():
y = w + 1
y.backward()
assert_raises(ValueError, invalid_trainer.step, 1)
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
update_on_kvstore=False)
for i in range(10):
with mx.autograd.record():
for w in x.list_data():
y = w + 1
y.backward()
trainer.step(1)
if i % freq == 0:
assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
lr *= factor
mx.nd.waitall()