# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS


class PrecisionPlugin(CheckpointHooks):
    """Base class for all plugins handling the precision-specific parts of the training.

    The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
    """

    precision: Union[str, int] = 32
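
    # Hypothetical sketch (not part of this file): a child class typically overrides ``precision``
    # and whichever of the hooks below need precision-specific behaviour, e.g.
    #
    #   class MyHalfPrecisionPlugin(PrecisionPlugin):
    #       precision = 16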

    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
        """The main params of the model.

        Returns the plain model params here. May be different in other precision plugins.
        """
        for group in optimizer.param_groups:
            yield from group["params"]
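
    # For example (hypothetical, not implemented here), a mixed-precision plugin that keeps FP32
    # master copies of the parameters could override ``main_params`` to yield those copies instead.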

    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        return model, optimizers, lr_schedulers

    def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
        """Run before precision plugin executes backward.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
        """
        model.trainer._call_callback_hooks("on_before_backward", closure_loss)
        model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
        return closure_loss

    def backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Performs the actual backpropagation.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
            optimizer: current optimizer being used. ``None`` if using manual optimization
        """
        # do backward pass
        if model is not None and isinstance(model, pl.LightningModule):
            model.backward(closure_loss, optimizer, *args, **kwargs)
        else:
            # forward the raw model (possibly ``None``) so the call matches ``_run_backward``'s signature
            self._run_backward(closure_loss, model, *args, **kwargs)
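
    # Note: the ``LightningModule`` branch above defers to ``LightningModule.backward``, whose
    # default implementation simply calls ``loss.backward(*args, **kwargs)``; routing through the
    # module is what lets users override the backward behaviour.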

    def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
        """Run after precision plugin executes backward.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
        """
        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()
        model.trainer._call_callback_hooks("on_after_backward")
        model.trainer._call_lightning_module_hook("on_after_backward")
        return closure_loss

    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        """Lightning-independent backward logic.

        Currently only used by Lightning Lite. Subject to further refactors.
        """
        tensor.backward(*args, **kwargs)

    def _after_closure(
        self, model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int
    ) -> None:
        """Utility to share some code after the closure has been run."""
        if not isinstance(model, pl.LightningModule):
            # none of this applies to Lite
            return
        trainer = model.trainer
        assert trainer is not None
        trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
        trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
        # TODO: this is done for the entire model but should be changed to per-optimizer
        if optimizer_idx == 0:
            self._track_grad_norm(trainer)
        self._clip_gradients(
            model,
            optimizer,
            optimizer_idx,
            trainer.gradient_clip_val,
            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
        )

    def _wrap_closure(
        self,
        model: "pl.LightningModule",
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
    ) -> Any:
        """This double-closure makes sure the ``closure`` is executed before the
        ``on_before_optimizer_step`` hook is called.

        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
        consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
        """
        closure_result = closure()
        self._after_closure(model, optimizer, optimizer_idx)
        return closure_result

    def optimizer_step(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> None:
        """Hook to run the optimizer step."""
        if isinstance(model, pl.LightningModule):
            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        optimizer.step(closure=closure, **kwargs)
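
    # Subclasses that use a gradient scaler (for example, a native AMP plugin) typically override
    # this hook so the step is routed through the scaler rather than calling ``optimizer.step``
    # directly, while still running the wrapped closure first.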

    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
        if trainer.track_grad_norm == -1:
            return
        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
        if grad_norm_dict:
            prev_fx = trainer.lightning_module._current_fx_name
            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
            trainer.lightning_module.log_grad_norm(grad_norm_dict)
            trainer.lightning_module._current_fx_name = prev_fx
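
    # ``track_grad_norm=-1`` (the Trainer default) disables tracking; any other value is forwarded
    # to ``grad_norm`` as the norm type, e.g. ``Trainer(track_grad_norm=2)`` logs the 2-norm of the
    # gradients under the ``on_before_optimizer_step`` hook name.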

    def _clip_gradients(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Optimizer,
        optimizer_idx: int,
        clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
    ) -> None:
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
            # the configuration validator disallows clipping on manual
            return
        model.configure_gradient_clipping(
            optimizer,
            optimizer_idx,
            gradient_clip_val=clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Clips the gradients."""
        if clip_val <= 0:
            return
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            self.clip_grad_by_value(optimizer, clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            self.clip_grad_by_norm(optimizer, clip_val)

    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)
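
    # Illustrative note on the end-to-end flow (reconstructed, not code in this file): a Trainer
    # configured with ``gradient_clip_val``/``gradient_clip_algorithm`` reaches these methods via
    # ``_after_closure`` -> ``_clip_gradients`` -> ``LightningModule.configure_gradient_clipping``,
    # whose default implementation routes back into ``clip_gradients`` above.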

    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something when ``Strategy.dispatch()`` gets called."""

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
        yield
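
    # The base class intentionally yields without entering any context. Hypothetical sketch of how
    # a subclass could specialize ``forward_context`` for half-precision autocasting:
    #
    #   @contextlib.contextmanager
    #   def forward_context(self):
    #       with torch.cuda.amp.autocast():
    #           yield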

    @contextlib.contextmanager
    def train_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the training step."""
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def val_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the validation step."""
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def test_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the test step."""
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def predict_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the predict step."""
        with self.forward_context():
            yield

    def teardown(self) -> None:
        """This method is called to teardown the training process.

        It is the right place to release memory and free other resources.
        """