 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Any, Dict, List, Type
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from torch.cuda.amp.grad_scaler import GradScaler
@@ -36,13 +36,15 @@ class OptimizerConfig(base_config.PrintableConfig):
     _target: Type = torch.optim.Adam
     lr: float = 0.0005
     eps: float = 1e-08
+    max_norm: Optional[float] = None
 
     # TODO: somehow make this more generic. i dont like the idea of overriding the setup function
     # but also not sure how to go about passing things into predefined torch objects.
     def setup(self, params) -> Any:
         """Returns the instantiated object using the config."""
         kwargs = vars(self).copy()
         kwargs.pop("_target")
+        kwargs.pop("max_norm")
         return self._target(params, **kwargs)
 
 
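The new `max_norm` field rides along on the optimizer config but is not a `torch.optim.Adam` constructor argument, which is why `setup` pops it (alongside `_target`) before instantiating the optimizer. Below is a minimal standalone sketch of that pattern; the class name `ClippedAdamConfig` and the tiny model are illustrative, not part of the repository.

```python
# Standalone sketch (hypothetical names, not from the repository): config fields
# that are not optimizer constructor kwargs, such as max_norm, must be popped
# before the torch optimizer is instantiated.
from dataclasses import dataclass
from typing import Any, Optional, Type

import torch


@dataclass
class ClippedAdamConfig:  # illustrative stand-in for OptimizerConfig
    _target: Type = torch.optim.Adam
    lr: float = 0.0005
    eps: float = 1e-08
    max_norm: Optional[float] = None  # consumed by the training loop, never by Adam

    def setup(self, params) -> Any:
        kwargs = vars(self).copy()
        kwargs.pop("_target")
        kwargs.pop("max_norm")  # Adam raises TypeError on unknown kwargs
        return self._target(params, **kwargs)


model = torch.nn.Linear(4, 2)
optimizer = ClippedAdamConfig(max_norm=1.0).setup(model.parameters())
print(type(optimizer).__name__)  # -> Adam
```

Popping the field keeps the config flat and printable while avoiding an unexpected-keyword-argument error from the torch optimizer constructor.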
@@ -73,9 +75,11 @@ def __init__(self, config: Dict[str, Any], param_groups: Dict[str, List[Paramete
         self.config = config
         self.optimizers = {}
         self.schedulers = {}
+        self.parameters = {}
         for param_group_name, params in param_groups.items():
             lr_init = config[param_group_name]["optimizer"].lr
             self.optimizers[param_group_name] = config[param_group_name]["optimizer"].setup(params=params)
+            self.parameters[param_group_name] = params
             if config[param_group_name]["scheduler"]:
                 self.schedulers[param_group_name] = config[param_group_name]["scheduler"].setup(
                     optimizer=self.optimizers[param_group_name], lr_init=lr_init
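`__init__` now also keeps the raw parameter lists keyed by group name, so the step functions below can clip each group's gradients independently. A rough sketch of that bookkeeping, with made-up group names and a toy model:

```python
# Sketch of the new per-group bookkeeping (group names and model are made up):
# keeping the parameter lists keyed by group name lets a later step clip each
# group's gradients on its own before calling that group's optimizer.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
param_groups = {
    "fields": list(model[0].parameters()),  # hypothetical group names
    "heads": list(model[1].parameters()),
}

optimizers, parameters = {}, {}
for name, params in param_groups.items():
    optimizers[name] = torch.optim.Adam(params, lr=5e-4)
    parameters[name] = params  # mirrors self.parameters[param_group_name] = params

loss = model(torch.randn(3, 8)).sum()
loss.backward()
for name, optimizer in optimizers.items():
    torch.nn.utils.clip_grad_norm_(parameters[name], max_norm=1.0)  # per-group clip
    optimizer.step()
```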
@@ -109,13 +113,20 @@ def optimizer_scaler_step_all(self, grad_scaler: GradScaler) -> None:
         Args:
             grad_scaler: GradScaler to use
         """
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             grad_scaler.step(optimizer)
 
     def optimizer_step_all(self):
         """Run step for all optimizers."""
-        for _, optimizer in self.optimizers.items():
+        for param_group, optimizer in self.optimizers.items():
             # note that they key is the parameter name
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
             optimizer.step()
 
     def scheduler_step_all(self, step: int) -> None:
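In the scaler path, clipping only makes sense after the gradients are unscaled, which is why `optimizer_scaler_step_all` calls `grad_scaler.unscale_(optimizer)` before `clip_grad_norm_` and then lets `grad_scaler.step` finish (it skips the unscale it already performed). Here is a self-contained sketch of that ordering, following the standard PyTorch AMP recipe; the model, sizes, and `max_norm` value are placeholders, and the autocast forward pass is omitted for brevity.

```python
# Sketch of the unscale -> clip -> step ordering used above (standard PyTorch
# AMP recipe). Model, sizes, and the max_norm value are placeholders. With the
# scaler enabled, gradients are scaled after backward(), so their norm is only
# meaningful once grad_scaler.unscale_() has been called on the optimizer.
import torch
from torch.cuda.amp.grad_scaler import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
grad_scaler = GradScaler(enabled=(device == "cuda"))  # passthrough no-op on CPU
max_norm = 1.0  # stands in for OptimizerConfig.max_norm

loss = model(torch.randn(8, 16, device=device)).sum()
grad_scaler.scale(loss).backward()

grad_scaler.unscale_(optimizer)                               # undo the loss scale first
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # norm is now in real units
grad_scaler.step(optimizer)                                   # skips the unscale it already did
grad_scaler.update()
optimizer.zero_grad()
```

When `max_norm` is left at its default of `None`, both step functions fall through to the original behaviour, so existing configs are unaffected.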