diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py
index c3a1f3374a94..edbcfaddefb5 100644
--- a/python/mxnet/optimizer/optimizer.py
+++ b/python/mxnet/optimizer/optimizer.py
@@ -63,8 +63,11 @@ class Optimizer(object):
     clip_gradient : float, optional, default None
         Clip the gradient by projecting onto the box ``[-clip_gradient, clip_gradient]``.
 
-    learning_rate : float, optional, default 0.01
-        The initial learning rate.
+    learning_rate : float, optional, default None
+        The initial learning rate. If None, the optimization will use the
+        learning rate from ``lr_scheduler``. If not None, it will overwrite
+        the learning rate in ``lr_scheduler``. If None and ``lr_scheduler``
+        is also None, then it will be set to 0.01 by default.
 
     lr_scheduler : LRScheduler, optional, default None
         The learning rate scheduler.
@@ -97,14 +100,19 @@ class Optimizer(object):
     optimizer, its learning rate can be accessed as optimizer.learning_rate.
     """
     def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
-                 clip_gradient=None, learning_rate=0.01,
+                 clip_gradient=None, learning_rate=None,
                  lr_scheduler=None, sym=None, begin_num_update=0,
                  multi_precision=False, param_dict=None):
         self.rescale_grad = rescale_grad
-        self.lr = learning_rate
         self.lr_scheduler = lr_scheduler
-        if lr_scheduler is not None:
-            self.lr_scheduler.base_lr = learning_rate
+        if self.lr_scheduler is None and learning_rate is None:
+            learning_rate = 0.01
+        self.lr = learning_rate
+        if self.lr_scheduler is not None and learning_rate is not None:
+            if self.lr_scheduler.base_lr != learning_rate:
+                print(UserWarning("learning rate from ``lr_scheduler`` has been "
+                                  "overwritten by ``learning_rate`` in optimizer."))
+                self.lr_scheduler.base_lr = learning_rate
 
         self.wd = wd
         self.lr_mult = {}
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index cea469960f64..f9adf63b24e2 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -39,6 +39,10 @@ def test_learning_rate():
     o2.lr_scheduler.base_lr = 0.4
     assert o2.learning_rate == 0.4
 
+    lr_s = lr_scheduler.FactorScheduler(step=1, base_lr=1024)
+    o3 = mx.optimizer.Optimizer(lr_scheduler=lr_s)
+    assert o3.learning_rate == 1024
+
 
 @raises(UserWarning)
 @with_seed()
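
Below is a short usage sketch, not part of the patch itself, illustrating the precedence rules this change introduces between ``learning_rate`` and ``lr_scheduler``. It assumes an MXNet build that includes this patch; the scheduler arguments mirror the updated unit test.

import mxnet as mx
from mxnet import lr_scheduler

# No learning_rate and no lr_scheduler: falls back to the old default of 0.01.
opt = mx.optimizer.Optimizer()
assert opt.learning_rate == 0.01

# Only lr_scheduler given: its base_lr is used instead of a hard-coded 0.01.
sched = lr_scheduler.FactorScheduler(step=1, base_lr=1024)
opt = mx.optimizer.Optimizer(lr_scheduler=sched)
assert opt.learning_rate == 1024

# Both given and they disagree: learning_rate wins, base_lr is overwritten,
# and the patched __init__ prints a UserWarning about the overwrite.
sched = lr_scheduler.FactorScheduler(step=1, base_lr=0.5)
opt = mx.optimizer.Optimizer(learning_rate=0.1, lr_scheduler=sched)
assert opt.learning_rate == 0.1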