This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-535] Fix bugs in LR Schedulers and add warmup #11234

Merged — 3 commits merged on Aug 26, 2018
example/image-classification/common/fit.py (3 changes: 2 additions & 1 deletion)
@@ -49,7 +49,8 @@ def _get_lr_scheduler(args, kv):
steps = [epoch_size * (x - begin_epoch)
for x in step_epochs if x - begin_epoch > 0]
if steps:
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor,
base_lr=args.lr))
else:
return (lr, None)

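The change above constructs the MultiFactorScheduler with args.lr as its base_lr instead of leaving the scheduler on its default of 0.01. A minimal sketch of the new call, with hypothetical stand-ins for the parsed args:

import mxnet as mx

# Hypothetical stand-ins for args.lr, the epoch size and the decay epochs
lr = 0.1
epoch_size = 500
steps = [epoch_size * e for e in (30, 60)]  # decay at epochs 30 and 60

sched = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=0.1, base_lr=lr)
print(sched(0))  # 0.1 -- the scheduler now starts from args.lr, not the default 0.01
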
python/mxnet/lr_scheduler.py (147 changes: 129 additions & 18 deletions)
@@ -17,6 +17,7 @@

"""Scheduling learning rate."""
import logging
from math import cos, pi

class LRScheduler(object):
"""Base class of a learning rate scheduler.
@@ -28,9 +29,41 @@ class LRScheduler(object):
----------
base_lr : float, optional
The initial learning rate.
warmup_steps: int
number of warmup steps used before this scheduler starts decay
warmup_begin_lr: float
if using warmup, the learning rate from which it starts warming up
warmup_mode: string
warmup can be done in two modes.
'linear' mode gradually increases lr with each step in equal increments
'constant' mode keeps lr at warmup_begin_lr for warmup_steps
"""
def __init__(self, base_lr=0.01):
def __init__(self, base_lr=0.01,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
self.base_lr = base_lr
assert isinstance(warmup_steps, int)
self.warmup_steps = warmup_steps

self.warmup_final_lr = base_lr
self.warmup_begin_lr = warmup_begin_lr
if self.warmup_begin_lr > self.warmup_final_lr:
raise ValueError("Base lr has to be higher than warmup_begin_lr")
if self.warmup_steps < 0:
raise ValueError("Warmup steps has to be positive or 0")
if warmup_mode not in ['linear', 'constant']:
raise ValueError("Supports only linear and constant modes of warmup")
self.warmup_mode = warmup_mode

def get_warmup_lr(self, num_update):
assert num_update < self.warmup_steps
if self.warmup_mode == 'linear':
increase = (self.warmup_final_lr - self.warmup_begin_lr) \
* float(num_update) / float(self.warmup_steps)
return self.warmup_begin_lr + increase
elif self.warmup_mode == 'constant':
return self.warmup_begin_lr
else:
raise ValueError("Invalid warmup mode %s"%self.warmup_mode)

def __call__(self, num_update):
"""Return a new learning rate.
@@ -66,8 +99,9 @@ class FactorScheduler(LRScheduler):
stop_factor_lr : float, optional
Stop updating the learning rate if it is less than this value.
"""
def __init__(self, step, factor=1, stop_factor_lr=1e-8):
super(FactorScheduler, self).__init__()
def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(FactorScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
if step < 1:
raise ValueError("Schedule step must be greater or equal than 1 round")
if factor > 1.0:
@@ -78,6 +112,9 @@ def __init__(self, step, factor=1, stop_factor_lr=1e-8):
self.count = 0

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)

# NOTE: use while rather than if (for continuing training via load_epoch)
while num_update > self.count + self.step:
self.count += self.step
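As an aside, a hedged usage sketch of the warmup-enabled FactorScheduler (illustrative values; the expected outputs follow from the linear-warmup and factor-decay rules above):

import mxnet as mx

sched = mx.lr_scheduler.FactorScheduler(step=250, factor=0.5, base_lr=0.5,
                                        warmup_steps=10, warmup_begin_lr=0.0,
                                        warmup_mode='linear')
print(sched(5))    # 0.25 -- halfway through the linear warmup
print(sched(10))   # 0.5  -- warmup finished, decay starts from base_lr
print(sched(251))  # 0.25 -- first decay by `factor` after 250 updates
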
@@ -108,9 +145,19 @@ class MultiFactorScheduler(LRScheduler):
The list of steps to schedule a change
factor: float
The factor to change the learning rate.
warmup_steps: int
number of warmup steps used before this scheduler starts decay
warmup_begin_lr: float
if using warmup, the learning rate from which it starts warming up
warmup_mode: string
warmup can be done in two modes.
'linear' mode gradually increases lr with each step in equal increments
'constant' mode keeps lr at warmup_begin_lr for warmup_steps
"""
def __init__(self, step, factor=1):
super(MultiFactorScheduler, self).__init__()
def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0,
warmup_mode='linear'):
super(MultiFactorScheduler, self).__init__(base_lr, warmup_steps,
warmup_begin_lr, warmup_mode)
assert isinstance(step, list) and len(step) >= 1
for i, _step in enumerate(step):
if i != 0 and step[i] <= step[i-1]:
@@ -125,6 +172,9 @@ def __init__(self, step, factor=1):
self.count = 0

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)

# NOTE: use while rather than if (for continuing training via load_epoch)
while self.cur_step_ind <= len(self.step)-1:
if num_update > self.step[self.cur_step_ind]:
@@ -138,33 +188,94 @@ def __call__(self, num_update):
return self.base_lr
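A scheduler like this is usually handed to the optimizer rather than called directly. A hedged wiring sketch (illustrative values; keeping the optimizer's learning_rate consistent with the scheduler's base_lr avoids the two disagreeing):

import mxnet as mx

sched = mx.lr_scheduler.MultiFactorScheduler(step=[3000, 6000], factor=0.1,
                                             base_lr=0.1, warmup_steps=500,
                                             warmup_begin_lr=0.001,
                                             warmup_mode='linear')
# the optimizer queries the scheduler with the current update count at each step
opt = mx.optimizer.SGD(learning_rate=0.1, momentum=0.9, lr_scheduler=sched)
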

class PolyScheduler(LRScheduler):
""" Reduce the learning rate by given a list of steps.
""" Reduce the learning rate according to a polynomial of given power.

Calculate the new learning rate by::
Calculate the new learning rate, after warmup if any, by::

base_lr * (1-nup/max_nup)^pwr
final_lr + (start_lr - final_lr) * (1-nup/max_nup)^pwr
if nup < max_nup, 0 otherwise.

Parameters
----------
max_update: maximum number of updates before the decay reaches 0.
base_lr: base learning rate
pwr: power of the decay term as a funtion of the current number of updates.

max_update: int
maximum number of updates before the decay reaches final learning rate.
base_lr: float
base learning rate to start from
pwr: int
power of the decay term as a function of the current number of updates.
final_lr: float
final learning rate after all steps
warmup_steps: int
number of warmup steps used before this scheduler starts decay
warmup_begin_lr: float
if using warmup, the learning rate from which it starts warming up
warmup_mode: string
warmup can be done in two modes.
'linear' mode gradually increases lr with each step in equal increments
'constant' mode keeps lr at warmup_begin_lr for warmup_steps
"""

def __init__(self, max_update, base_lr=0.01, pwr=2):
super(PolyScheduler, self).__init__(base_lr)
def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(PolyScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
self.power = pwr
self.base_lr_orig = self.base_lr
self.max_update = max_update
self.power = pwr
self.base_lr = self.base_lr_orig
self.final_lr = final_lr
self.max_steps = self.max_update - self.warmup_steps

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)
if num_update <= self.max_update:
self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
pow(1 - float(num_update - self.warmup_steps) / float(self.max_steps), self.power)
return self.base_lr
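Illustrative values for the polynomial decay above (a hedged sketch with power 2, 100 linear warmup steps, 1000 total updates and final_lr 0; the numbers follow from the formula in the docstring):

import mxnet as mx

sched = mx.lr_scheduler.PolyScheduler(max_update=1000, base_lr=3, pwr=2, final_lr=0,
                                      warmup_steps=100, warmup_begin_lr=0,
                                      warmup_mode='linear')
print(sched(50))    # 1.5  -- halfway through the linear warmup
print(sched(100))   # 3.0  -- decay begins from base_lr once warmup ends
print(sched(550))   # 0.75 -- 3 * (1 - 450/900)**2
print(sched(1000))  # 0.0  -- final_lr reached at max_update
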

class CosineScheduler(LRScheduler):
""" Reduce the learning rate according to a cosine function

Calculate the new learning rate by::

final_lr + (start_lr - final_lr) * (1+cos(pi * nup/max_nup))/2
if nup < max_nup, 0 otherwise.

Parameters
----------
max_update: int
maximum number of updates before the decay reaches 0
base_lr: float
base learning rate
final_lr: float
final learning rate after all steps
warmup_steps: int
number of warmup steps used before this scheduler starts decay
warmup_begin_lr: float
if using warmup, the learning rate from which it starts warming up
warmup_mode: string
warmup can be done in two modes.
'linear' mode gradually increases lr with each step in equal increments
'constant' mode keeps lr at warmup_begin_lr for warmup_steps
"""

def __init__(self, max_update, base_lr=0.01, final_lr=0,
warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
super(CosineScheduler, self).__init__(base_lr, warmup_steps, warmup_begin_lr, warmup_mode)
assert isinstance(max_update, int)
if max_update < 1:
raise ValueError("maximum number of updates must be strictly positive")
self.base_lr_orig = base_lr
self.max_update = max_update
self.final_lr = final_lr
self.max_steps = self.max_update - self.warmup_steps

def __call__(self, num_update):
if num_update < self.warmup_steps:
return self.get_warmup_lr(num_update)
if num_update <= self.max_update:
self.base_lr = self.base_lr_orig * pow(1.0 - float(num_update) / float(self.max_update),
self.power)
self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * \
(1 + cos(pi * (num_update - self.warmup_steps) / self.max_steps)) / 2
return self.base_lr
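Illustrative values for the cosine decay above, without warmup (a hedged sketch; the numbers follow from the formula in the docstring):

import mxnet as mx

sched = mx.lr_scheduler.CosineScheduler(max_update=100, base_lr=1.0, final_lr=0.0)
print(sched(0))    # 1.0 -- cos(0) = 1, so the schedule starts at base_lr
print(sched(50))   # 0.5 -- cos(pi/2) = 0, halfway between base_lr and final_lr
print(sched(100))  # 0.0 -- cos(pi) = -1, final_lr is reached at max_update
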
tests/python/unittest/test_optimizer.py (52 changes: 52 additions & 0 deletions)
@@ -1033,6 +1033,58 @@ def test_adagrad():
w_stype='row_sparse', g_stype='row_sparse')


def test_factor_scheduler():
base_lr = 1
step = 100
factor = 0.1
sched = mx.lr_scheduler.FactorScheduler(step, factor, stop_factor_lr=1e-4, base_lr=base_lr,
warmup_steps=20, warmup_begin_lr=0.1, warmup_mode='constant')

assert (sched(0) == 0.1)
np.testing.assert_almost_equal(sched(10), 0.1)
assert (sched(21) == base_lr), sched(21)
np.testing.assert_almost_equal(sched(101), base_lr * factor)
np.testing.assert_almost_equal(sched(201), base_lr * factor * factor)
np.testing.assert_almost_equal(sched(1000), 1e-4)

def test_multifactor_scheduler():
base_lr = 0.1
steps = [15, 25]
factor = 0.1
sched = mx.lr_scheduler.MultiFactorScheduler(steps, factor, base_lr=base_lr,
warmup_steps=10, warmup_begin_lr=0.05, warmup_mode='linear')

assert sched(0) == 0.05
np.testing.assert_almost_equal(sched(5), 0.05 + (base_lr - 0.05)/2)
np.testing.assert_almost_equal(sched(15), base_lr)
np.testing.assert_almost_equal(sched(16), base_lr * factor)
np.testing.assert_almost_equal(sched(20), base_lr * factor)
np.testing.assert_almost_equal(sched(26), base_lr * factor * factor)
np.testing.assert_almost_equal(sched(100), base_lr * factor * factor)

def test_poly_scheduler():
base_lr = 3
final_lr = 0
steps = 1000
poly_sched = mx.lr_scheduler.PolyScheduler(steps, base_lr=base_lr, pwr=2, final_lr=final_lr,
warmup_steps=100, warmup_begin_lr=0, warmup_mode='linear')

np.testing.assert_almost_equal(poly_sched(0), 0)
np.testing.assert_almost_equal(poly_sched(50), float(base_lr)/2)
np.testing.assert_almost_equal(poly_sched(100), base_lr)
assert (poly_sched(101) < poly_sched(100))
assert (poly_sched(500) < 1.6)
np.testing.assert_almost_equal(poly_sched(steps), final_lr)

def test_cosine_scheduler():
# also tests case without warmup
base_lr = 3
final_lr = 0.1
steps = 1000
cosine_sched = mx.lr_scheduler.CosineScheduler(steps, base_lr=base_lr, final_lr=final_lr)
np.testing.assert_almost_equal(cosine_sched(0), base_lr)
np.testing.assert_almost_equal(cosine_sched(steps), final_lr)
assert (cosine_sched(500) > 1.5)
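A further check one could add (a hedged sketch, not part of this PR's test suite) is that the cosine decay hands off cleanly from a linear warmup: the learning rate should sit exactly at base_lr when the warmup ends and still reach final_lr at max_update.

def test_cosine_scheduler_with_warmup():  # hypothetical extra test, illustrative only
    base_lr = 3
    final_lr = 0.1
    steps = 1000
    sched = mx.lr_scheduler.CosineScheduler(steps, base_lr=base_lr, final_lr=final_lr,
                                            warmup_steps=100, warmup_begin_lr=0,
                                            warmup_mode='linear')
    np.testing.assert_almost_equal(sched(100), base_lr)    # warmup ends exactly at base_lr
    np.testing.assert_almost_equal(sched(steps), final_lr)  # still lands on final_lr
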

if __name__ == '__main__':
import nose