Merged
Changes from all commits (21 commits):
50ca388  Add Linear warmup+decay lr schedule (tjruwase, Sep 16, 2020)
f236ba5  Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase… (tjruwase, Sep 18, 2020)
288fc96  LR scheduler unit tests for LR Range Test and 1Cycle (tjruwase, Sep 21, 2020)
596e9a6  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 21, 2020)
62d3c91  Disable yapf to preserve parameterizaton (tjruwase, Sep 22, 2020)
997d5d5  Merge branch 'olruwase/lr_warmup_decay' of github.com:microsoft/DeepS… (tjruwase, Sep 22, 2020)
4026bf5  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 22, 2020)
4d343b5  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 25, 2020)
c4b36e2  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 27, 2020)
55487b2  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 28, 2020)
5c2b37d  Disable test_pipe.py for CI debugging (tjruwase, Sep 28, 2020)
f0d0e21  Disable test_lr_scheduler for CI debugging (tjruwase, Sep 28, 2020)
3a1c9cb  Disable test_lr_scheduler for CI debugging (tjruwase, Sep 28, 2020)
bac425f  Enable all unit tests for CI debugging (tjruwase, Sep 28, 2020)
54e689f  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Sep 30, 2020)
560ef1c  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Oct 6, 2020)
f73ad56  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Oct 7, 2020)
e98bad0  Merge branch 'master' into olruwase/lr_warmup_decay (tjruwase, Oct 9, 2020)
d3e97cf  Merge branch 'master' into olruwase/lr_warmup_decay (jeffra, Dec 2, 2020)
b395994  Merge branch 'master' into olruwase/lr_warmup_decay (jeffra, Jan 8, 2021)
b268a16  Merge branch 'master' of github.com:microsoft/DeepSpeed into olruwase… (tjruwase, Jan 8, 2021)
deepspeed/runtime/lr_schedules.py: 86 changes (52 additions & 34 deletions)
@@ -367,10 +367,10 @@ def __init__(self,
self._update_optimizer(self.min_lr)

def _staircase_interval(self):
return math.floor(float(self.last_batch_iteration) / self.step_size)
return math.floor(float(self.last_batch_iteration + 1) / self.step_size)

def _continous_interval(self):
return float(self.last_batch_iteration) / self.step_size
return float(self.last_batch_iteration + 1) / self.step_size

def _get_increase(self):
return (1 + self.step_rate * self.interval_fn())
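
The `+ 1` shift in the two interval helpers above assumes the scheduler's `last_batch_iteration` starts at -1 before the first `step()` call, so the very first optimizer step lands in interval 0 and uses the base learning rate. A minimal standalone sketch of the shifted interval arithmetic; the step_size value here is illustrative, not taken from this PR:

import math

def staircase_interval(last_batch_iteration, step_size):
    # Staircase interval index, shifted so that iteration -1 maps to interval 0.
    return math.floor(float(last_batch_iteration + 1) / step_size)

def continuous_interval(last_batch_iteration, step_size):
    # Continuous (fractional) interval with the same +1 shift.
    return float(last_batch_iteration + 1) / step_size

# With step_size = 2, iterations -1..3 give:
print([staircase_interval(i, 2) for i in range(-1, 4)])   # [0, 0, 1, 1, 2]
print([continuous_interval(i, 2) for i in range(-1, 4)])  # [0.0, 0.5, 1.0, 1.5, 2.0]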
@@ -574,66 +574,73 @@ def _initialize_momentum(self,
for momentum, group in zip(self.min_moms, optimizer.param_groups):
group['betas'] = momentum

def _get_cycle_lr(self):
cycle = math.floor(1 + self.last_batch_iteration / self.total_size)
x = 1. + self.last_batch_iteration / self.total_size - cycle
def _get_scale_factor(self):
batch_iteration = (self.last_batch_iteration + 1)
cycle = math.floor(1 + batch_iteration / self.total_size)
x = 1. + batch_iteration / self.total_size - cycle
if x <= self.step_ratio:
scale_factor = x / self.step_ratio
else:
scale_factor = (x - 1) / (self.step_ratio - 1)

return scale_factor

def _get_cycle_mom(self):
scale_factor = self._get_scale_factor()
momentums = []
for base_betas, max_betas in zip(self.min_moms, self.max_moms):
cycle_min_mom = base_betas[0]
cycle_max_mom = max_betas[0]
base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
momentum = cycle_max_mom - base_height
momentums.append((momentum, base_betas[1]))
return momentums

def _get_cycle_lr(self):
scale_factor = self._get_scale_factor()
lrs = []
for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs):
base_height = (cycle_max_lr - cycle_min_lr) * scale_factor
lr = cycle_min_lr + base_height
lrs.append(lr)

if self.cycle_momentum:
momentums = []
for base_betas, max_betas in zip(self.min_moms, self.max_moms):
cycle_min_mom = base_betas[0]
cycle_max_mom = max_betas[0]
base_height = (cycle_max_mom - cycle_min_mom) * scale_factor
momentum = cycle_max_mom - base_height
momentums.append((momentum, base_betas[1]))
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum

return lrs

def _get_decay_mom(self, decay_batch_iteration):
decay_interval = decay_batch_iteration / self.decay_step_size
mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms]
return momentums

def _get_decay_lr(self, decay_batch_iteration):
"""Calculates the learning rate at batch index. This function is used
after the cycle completes and post cycle decaying of lr/mom is enabled.
This function treats `self.last_batch_iteration` as the last batch index.

If `self.cycle_momentum` is ``True``, this function has a side effect of
updating the optimizer's momentum.
"""
decay_interval = decay_batch_iteration / self.decay_step_size

lr_decay_factor = (1 + self.decay_lr_rate * decay_interval)
lrs = [cycle_min_lr * lr_decay_factor for cycle_min_lr in self.min_lrs]

if self.cycle_momentum:
mom_decay_factor = (1 + self.decay_mom_rate * decay_interval)
momentums = [(beta0 * mom_decay_factor,
beta1) for beta0,
beta1 in self.max_moms]
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum
lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs]

return lrs

def get_lr(self):
"""Calculates the learning rate at batch index. This function treats
`self.last_batch_iteration` as the last batch index.

If `self.cycle_momentum` is ``True``, this function has a side effect of
updating the optimizer's momentum.
"""
if self.last_batch_iteration <= self.total_size:
if self.last_batch_iteration < self.total_size:
return self._get_cycle_lr()
return self._get_decay_lr(self.last_batch_iteration - self.total_size)
return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1)

def get_mom(self):
"""Calculates the momentum at batch index. This function treats
`self.last_batch_iteration` as the last batch index.
"""
if not self.cycle_momentum:
return None

if self.last_batch_iteration < self.total_size:
return self._get_cycle_mom()
return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1)

def get_last_lr(self):
""" Return last computed learning rate by current scheduler.
@@ -642,13 +642,24 @@ def get_last_lr(self):
return self._last_lr

def step(self, batch_iteration=None):
""" Updates the optimizer with the learning rate for the last batch index.
`self.last_batch_iteration` is treated as the last batch index.

If self.cycle_momentum is true, also updates optimizer momentum.
"""
if batch_iteration is None:
batch_iteration = self.last_batch_iteration + 1

self.last_batch_iteration = batch_iteration
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

if self.cycle_momentum:
momentums = self.get_mom()
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
param_group['betas'] = momentum

def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}
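
Taken together, the refactor separates LR and momentum computation (`_get_cycle_lr`/`_get_cycle_mom` during the cycle, `_get_decay_lr`/`_get_decay_mom` afterwards) and moves the momentum update into `step()`. A self-contained sketch of the resulting LR trajectory, using made-up hyperparameters rather than the scheduler's real constructor arguments:

import math

# Illustrative values only: one triangular cycle of 8 steps, then post-cycle decay.
cycle_min_lr, cycle_max_lr = 0.001, 0.01
total_size, step_ratio = 8, 0.5
decay_step_size, decay_lr_rate = 4, 1.0

def scale_factor(last_batch_iteration):
    # Mirrors _get_scale_factor: position within the current triangular cycle.
    batch_iteration = last_batch_iteration + 1
    cycle = math.floor(1 + batch_iteration / total_size)
    x = 1. + batch_iteration / total_size - cycle
    return x / step_ratio if x <= step_ratio else (x - 1) / (step_ratio - 1)

def lr_at(last_batch_iteration):
    if last_batch_iteration < total_size:
        # Cycle phase: interpolate from cycle_min_lr up to cycle_max_lr and back.
        return cycle_min_lr + (cycle_max_lr - cycle_min_lr) * scale_factor(last_batch_iteration)
    # Decay phase: shrink the minimum LR, as in the updated _get_decay_lr.
    decay_iteration = last_batch_iteration - total_size + 1
    return cycle_min_lr / (1 + decay_lr_rate * decay_iteration / decay_step_size)

# LR peaks at cycle_max_lr mid-cycle (iteration 3), returns to cycle_min_lr by
# iteration 7, and decays from iteration 8 onwards.
print([round(lr_at(i), 5) for i in range(-1, 12)])

With the `< self.total_size` comparison and the `+ 1` offsets, iteration `total_size - 1` is the last cycle step and decay begins exactly at iteration `total_size`; during the cycle, momentum moves in the opposite direction, from its maximum down to its minimum and back.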
