optimizer.py
from typing import Callable, Dict, List # isort:skip
import logging
import safitty
import torch
from catalyst.dl.core import Callback, CallbackOrder, RunnerState
from catalyst.dl.registry import GRAD_CLIPPERS
from catalyst.dl.utils import get_optimizer_momentum, maybe_recursive_call
from catalyst.dl.utils.torch import _Optimizer

logger = logging.getLogger(__name__)


class OptimizerCallback(Callback):
"""
Optimizer callback, abstraction over optimizer step.
"""

    def __init__(
self,
grad_clip_params: Dict = None,
accumulation_steps: int = 1,
optimizer_key: str = None,
loss_key: str = "loss",
decouple_weight_decay: bool = True
):
"""
Args:
grad_clip_params (dict): params for gradient clipping
accumulation_steps (int): number of steps before
``model.zero_grad()``
optimizer_key (str): A key to take a optimizer in case
there are several of them and they are in a dictionary format.
loss_key (str): key to get loss from ``state.loss``
decouple_weight_decay (bool): If True - decouple weight decay
regularization.
"""
super().__init__(CallbackOrder.Optimizer)
grad_clip_params: dict = grad_clip_params or {}
self.grad_clip_fn = GRAD_CLIPPERS.get_from_params(**grad_clip_params)
self.accumulation_steps: int = accumulation_steps
self.optimizer_key: str = optimizer_key
self.loss_key: str = loss_key
self.decouple_weight_decay = decouple_weight_decay
self._optimizer_wd: List[float] = [0.0]
self._accumulation_counter: int = 0

    @staticmethod
    def grad_step(
        *,
        optimizer: _Optimizer,
        optimizer_wds: List[float] = None,
        grad_clip_fn: Callable = None
    ):
        """Makes a gradient step for the given optimizer."""
        # Default to zero weight decay for every param group
        # if no per-group values were provided.
        optimizer_wds = \
            optimizer_wds or [0.0] * len(optimizer.param_groups)
        for group, wd in zip(optimizer.param_groups, optimizer_wds):
            # Decoupled weight decay: shrink the weights directly,
            # p <- p - lr * wd * p, instead of folding wd * p
            # into the gradients.
            if wd > 0:
                for param in group["params"]:
                    param.data = param.data.add(
                        -wd * group["lr"], param.data
                    )
            if grad_clip_fn is not None:
                grad_clip_fn(group["params"])
        optimizer.step()
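
    # Illustration (editorial note, not part of the original module):
    # the decoupled step above rescales every weight by
    # (1 - lr * weight_decay) before ``optimizer.step()``; e.g. with
    # lr=0.1 and weight_decay=0.01 each parameter is multiplied by 0.999.
    # This mirrors AdamW-style weight decay, applied to the weights
    # directly rather than through the gradients.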

    def on_stage_start(self, state: RunnerState):
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
assert optimizer is not None
lr = optimizer.defaults["lr"]
momentum = get_optimizer_momentum(optimizer)
state.set_key(lr, "lr", inner_key=self.optimizer_key)
state.set_key(momentum, "momentum", inner_key=self.optimizer_key)

    def on_epoch_start(self, state):
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
if self.decouple_weight_decay:
self._optimizer_wd = [
group.get("weight_decay", 0.0)
for group in optimizer.param_groups
]
for i in range(len(optimizer.param_groups)):
safitty.set(
optimizer.param_groups, i, "weight_decay", value=0.0)
else:
self._optimizer_wd = [0.0] * len(optimizer.param_groups)

    def _get_loss(self, state) -> torch.Tensor:
        loss = state.get_key(key="loss", inner_key=self.loss_key)
        if isinstance(loss, list):
            raise ValueError(
                "Loss is a list. To aggregate losses into one value "
                "use `CriterionAggregatorCallback`."
            )
        if isinstance(loss, dict):
            error = (
                f"Loss is a dict: {list(loss.keys())}. "
                "To aggregate losses into one value "
                "use `CriterionAggregatorCallback`."
            )
            if self.loss_key is None:
                error = (
                    error + " Alternatively, pass `loss_key` "
                    "to the OptimizerCallback constructor."
                )
            raise ValueError(error)
        return loss

    def on_batch_start(self, state):
        state.loss = None

    def on_batch_end(self, state):
if not state.need_backward:
return
loss = self._get_loss(state)
self._accumulation_counter += 1
model = state.model
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
        # The counter was already incremented above, so the optimizer
        # steps on every ``accumulation_steps``-th batch (e.g. with
        # ``accumulation_steps=4`` the step happens on batches 4, 8, ...).
        need_gradient_step = \
            self._accumulation_counter % self.accumulation_steps == 0
        # This is a very hacky check for whether we have an AMP-wrapped
        # optimizer, and it may change in the future.
        # An alternative solution would be a separate AmpOptimizerCallback
        # or another constructor argument.
if hasattr(optimizer, "_amp_stash"):
from apex import amp
# Need to set ``delay_unscale``
# according to
# https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
delay_unscale = not need_gradient_step
with amp.scale_loss(
loss,
optimizer,
delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if need_gradient_step:
self.grad_step(
optimizer=optimizer,
optimizer_wds=self._optimizer_wd,
grad_clip_fn=self.grad_clip_fn
)
maybe_recursive_call(model, "zero_grad")
self._accumulation_counter = 0

    def on_epoch_end(self, state):
if self.decouple_weight_decay:
optimizer = state.get_key(
key="optimizer", inner_key=self.optimizer_key
)
for i, wd in enumerate(self._optimizer_wd):
safitty.set(
optimizer.param_groups, i, "weight_decay", value=wd)


__all__ = ["OptimizerCallback"]
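

if __name__ == "__main__":
    # Editorial usage sketch (not part of the original module): a minimal
    # example of plugging ``OptimizerCallback`` into a training run, assuming
    # the ``SupervisedRunner`` API from the same catalyst release as this
    # file. The synthetic data, model and hyperparameters are illustrative
    # only; the explicit callback is expected to take the place of the
    # default optimizer callback.
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from catalyst.dl import SupervisedRunner

    model = nn.Linear(16, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=1e-3, weight_decay=1e-4
    )

    # Tiny synthetic classification dataset.
    features = torch.randn(128, 16)
    targets = torch.randint(0, 2, (128,))
    loaders = {
        "train": DataLoader(TensorDataset(features, targets), batch_size=8),
    }

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        callbacks=[
            # Accumulate gradients over 4 batches before each optimizer step
            # and apply weight decay in the decoupled (AdamW-style) manner
            # handled by this callback.
            OptimizerCallback(
                accumulation_steps=4,
                decouple_weight_decay=True,
            ),
        ],
        num_epochs=1,
        logdir="./logs",
        verbose=False,
    )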