diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 8060f38ac2aa..45a44d8eb3e4 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -241,10 +241,6 @@ def _init_kvstore(self):
                 kvstore.set_optimizer(self._optimizer)
             self._kvstore = kvstore
             self._update_on_kvstore = update_on_kvstore
-            if self._optimizer.lr_scheduler and not self._update_on_kvstore:
-                raise ValueError("update_on_kvstore=False does not support " \
-                                 "optimizer with LRScheduler. Please " \
-                                 "consider setting learning rate manually.")
         else:
             self._kvstore = None
             self._update_on_kvstore = None
diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index def2c958ede4..2e7fe86c5af9 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -106,7 +106,8 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
         self.wd_mult = {}
         self.begin_num_update = begin_num_update
         self.num_update = begin_num_update
-        self._index_update_count = {}
+        self._all_index_update_counts = {0 : {}}
+        self._index_update_count = self._all_index_update_counts[0]
         self.clip_gradient = clip_gradient
         self.multi_precision = multi_precision
         self.aggregate_num = 0
@@ -380,6 +381,18 @@ def set_wd_mult(self, args_wd_mult):
                     self.wd_mult[name] = float(attr[name]['__wd_mult__'])
         self.wd_mult.update(args_wd_mult)
 
+    def _set_current_context(self, device_id):
+        """Sets the number of the currently handled device.
+
+        Parameters
+        ----------
+        device_id : int
+            The number of current device.
+        """
+        if device_id not in self._all_index_update_counts:
+            self._all_index_update_counts[device_id] = {}
+        self._index_update_count = self._all_index_update_counts[device_id]
+
     def _update_count(self, index):
         """Updates num_update.
 
@@ -1623,6 +1636,8 @@ def __call__(self, index, grad, weight):
             indices = index
             grads = grad
             weights = weight
+        if weights:
+            self.optimizer._set_current_context(weights[0].context.device_id)
         for i, idx in enumerate(indices):
             # convert ctypes.char_p.value back to python str if needed
             if isinstance(idx, bytes):
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index 9f190a0a88c2..2d5874a8b97b 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -272,19 +272,22 @@ def test_trainer_lr_sched():
             lr *= factor
     mx.nd.waitall()
 
-@with_seed()
-def test_trainer_invalid_lr_sched():
+    # Update on kvstore = False
     x = gluon.Parameter('x', shape=(10,))
     x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
     freq = 2
     factor = 0.1
     lr = 1
     lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
-    invalid_trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
-                                    update_on_kvstore=False)
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    assert_raises(ValueError, invalid_trainer.step, 1)
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched},
+                            update_on_kvstore=False)
+    for i in range(10):
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        if i % freq == 0:
+            assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
+            lr *= factor
    mx.nd.waitall()
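
For context (this note and the sketch below are illustrative, not part of the patch): with update_on_kvstore=False, a single Optimizer instance updates a copy of every parameter on each device, so a shared _index_update_count would advance num_update once per device per step and drive an attached LRScheduler too fast. The patch instead keeps one update-count table per device_id and swaps the active table in before each device's updates. A minimal standalone sketch of that bookkeeping, where FakeScheduler and FakeOptimizer are hypothetical stand-ins for mx.lr_scheduler.* and mx.optimizer.Optimizer:

# Standalone sketch (not MXNet source) of the per-device update-count
# bookkeeping introduced above.
class FakeScheduler:
    """Decays the learning rate by 10x every 2 updates."""
    def __call__(self, num_update):
        return 1.0 * (0.1 ** (num_update // 2))

class FakeOptimizer:
    def __init__(self):
        self.num_update = 0
        # One update-count table per device, keyed by device id.
        self._all_index_update_counts = {0: {}}
        self._index_update_count = self._all_index_update_counts[0]
        self.lr_scheduler = FakeScheduler()

    def _set_current_context(self, device_id):
        # Swap in the table of the device whose weights are being
        # updated, so each counter advances once per step per device.
        if device_id not in self._all_index_update_counts:
            self._all_index_update_counts[device_id] = {}
        self._index_update_count = self._all_index_update_counts[device_id]

    def _update_count(self, index):
        count = self._index_update_count.get(index, 0) + 1
        self._index_update_count[index] = count
        self.num_update = max(count, self.num_update)

opt = FakeOptimizer()
for step in range(4):
    for device_id in (0, 1):  # the same weight lives on two devices
        opt._set_current_context(device_id)
        opt._update_count(0)

# num_update advanced once per step (4), not once per device update (8),
# so the scheduler decays at the intended rate: 0.1 ** (4 // 2) == 0.01.
assert opt.num_update == 4
assert abs(opt.lr_scheduler(opt.num_update) - 0.01) < 1e-9

With a single shared table, the inner device loop would have driven num_update to 8 and the scheduler to 0.0001, which is exactly the mis-scheduling the removed ValueError used to guard against and the new test now exercises directly.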