-
Notifications
You must be signed in to change notification settings - Fork 14
/
scheduler.py
24 lines (21 loc) · 947 Bytes
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import numpy as np
class MipLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Log-linear learning-rate decay with an optional sine-shaped warm-up.

    The learning rate is interpolated geometrically (linearly in log space).
    It starts at ``lr_init`` at step 0, reaches ``lr_final`` at ``max_steps``,
    and stays at ``lr_final`` afterwards. When ``lr_delay_steps > 0``, the
    rate is also multiplied by a warm-up factor. That factor rises from
    ``lr_delay_mult`` to 1 along a quarter sine wave over the first
    ``lr_delay_steps`` steps.

    The same learning rate is written to every parameter group. The ``lr``
    values the optimizer was constructed with are not used by this schedule.

    Args:
        optimizer: Wrapped optimizer.
        lr_init: Learning rate at step 0, before the warm-up factor. Must be
            > 0; ``np.log`` of 0 makes the interpolation non-finite.
        lr_final: Learning rate at and after ``max_steps``. Must be > 0.
        max_steps: Number of steps over which to decay. Must be non-zero
            (it is used as a divisor).
        lr_delay_steps: Length of the warm-up in steps; 0 disables it.
        lr_delay_mult: Warm-up factor applied at step 0.
    """

    def __init__(self, optimizer, lr_init, lr_final, max_steps, lr_delay_steps=0, lr_delay_mult=1):
        # These must be set before the base __init__, which performs an
        # initial step and therefore calls get_lr().
        self.lr_init = lr_init
        self.lr_final = lr_final
        self.max_steps = max_steps
        self.lr_delay_steps = lr_delay_steps
        self.lr_delay_mult = lr_delay_mult
        super(MipLRDecay, self).__init__(optimizer)

    def get_lr(self):
        """Return one learning rate per parameter group for ``self.last_epoch``."""
        step = self.last_epoch
        if self.lr_delay_steps > 0:
            # Warm-up factor: rises from lr_delay_mult (step 0) to 1
            # (step >= lr_delay_steps) along sin(0 .. pi/2).
            delay_rate = self.lr_delay_mult + (1 - self.lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / self.lr_delay_steps, 0, 1))
        else:
            delay_rate = 1.
        t = np.clip(step / self.max_steps, 0, 1)
        # exp(lerp(log a, log b, t)) == a**(1 - t) * b**t, i.e. geometric interpolation.
        log_lerp = np.exp(np.log(self.lr_init) * (1 - t) + np.log(self.lr_final) * t)
        # The scheduler zips this list with optimizer.param_groups. A
        # single-element list would leave every group after the first
        # untouched, so emit one entry per group. float() stores a plain
        # Python float rather than np.float64.
        lr = float(delay_rate * log_lerp)
        return [lr] * len(self.optimizer.param_groups)