@@ -1,7 +1,7 @@
 import argparse
 import os
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Union
 
 # torch imports
 import torch
@@ -113,6 +113,8 @@ def main(args):
         input_ids = input_ids.cuda()
         labels = labels.cuda()
 
+        optimizer.zero_grad()
+
         # forward
         pred = model(input_ids)
         tok_loss = F.cross_entropy(
@@ -123,12 +125,14 @@ def main(args):
         # backward on scaled loss to create scaled gradients
         scaler.scale(loss).backward()
 
+        # clip gradients (after unscaling gradients of the optimizer's params)
+        scaler.unscale_(optimizer)
+        model.clip_grad_norm_(args.max_norm)
+
         # optimizer step
-        # scaler.step() first unscales gradients of the optimizer's params.
-        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+        # If gradients don't contain infs/NaNs, optimizer.step() is then called;
         # otherwise, optimizer.step() is skipped.
         scaler.step(optimizer)
-        optimizer.zero_grad()
 
         # updates the scale for next iteration
         scaler.update()
@@ -168,6 +172,9 @@ def main(args):
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
     parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument(
+        "--max_norm", type=float, default=1.0, help="max norm for gradient clipping"
+    )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
     )