
Commit 2bb25ad

[torchtrain] add gradient clipping
ghstack-source-id: c5c3fe8
Pull Request resolved: #28

1 parent 2cb1b72 · commit 2bb25ad

File tree

1 file changed: +11 -4 lines


train.py

Lines changed: 11 additions & 4 deletions
@@ -1,7 +1,7 @@
 import argparse
 import os
 from dataclasses import dataclass, field
-from typing import List
+from typing import List, Union
 
 # torch imports
 import torch
@@ -113,6 +113,8 @@ def main(args):
         input_ids = input_ids.cuda()
         labels = labels.cuda()
 
+        optimizer.zero_grad()
+
         # forward
         pred = model(input_ids)
         tok_loss = F.cross_entropy(
@@ -123,12 +125,14 @@ def main(args):
         # backward on scaled loss to create scaled gradients
         scaler.scale(loss).backward()
 
+        # clip gradients (after unscaling gradients of the optimizer's params)
+        scaler.unscale_(optimizer)
+        model.clip_grad_norm_(args.max_norm)
+
         # optimizer step
-        # scaler.step() first unscales gradients of the optimizer's params.
-        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+        # If gradients don't contain infs/NaNs, optimizer.step() is then called;
         # otherwise, optimizer.step() is skipped.
         scaler.step(optimizer)
-        optimizer.zero_grad()
 
         # updates the scale for next iteration
         scaler.update()
@@ -168,6 +172,9 @@ def main(args):
         "--optimizer", type=str, default="AdamW", help="optimizer to use"
     )
     parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
+    parser.add_argument(
+        "--max_norm", type=Union[float, int], default=1.0, help="max norm for gradient clipping"
+    )
     parser.add_argument(
         "--steps", type=int, default=-1, help="how many train steps to run"
    )
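
For readers following the change: the hunks above reorder the training step so that gradients are zeroed before the forward pass, unscaled before clipping, and only then stepped through the GradScaler. Below is a minimal, self-contained sketch of that ordering under stated assumptions: a CUDA device, a toy model and synthetic batches, a float16 autocast region around the forward pass (the hunks only show the scaler calls), and the stock torch.nn.utils.clip_grad_norm_ standing in for the model object's own clip_grad_norm_ method used in train.py (which suggests an FSDP-style wrapper). It illustrates the pattern rather than reproducing the repository's training loop.

# Sketch only: toy model, synthetic data, assumed autocast region.
import torch
import torch.nn.functional as F

model = torch.nn.Linear(16, 16).cuda()        # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scaler = torch.cuda.amp.GradScaler()
max_norm = 1.0                                # stands in for args.max_norm

# synthetic (input, label) batches
batches = [(torch.randn(8, 16), torch.randint(0, 16, (8,))) for _ in range(3)]

for input_ids, labels in batches:
    input_ids, labels = input_ids.cuda(), labels.cuda()

    optimizer.zero_grad()

    # forward under autocast (assumed; not shown in the diff hunks)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        pred = model(input_ids)
        loss = F.cross_entropy(pred, labels)

    # backward on scaled loss to create scaled gradients
    scaler.scale(loss).backward()

    # unscale first so clipping compares true gradient norms against max_norm ...
    scaler.unscale_(optimizer)
    # ... then clip (plain-module equivalent of model.clip_grad_norm_ in the diff)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    # scaler.step() skips optimizer.step() if gradients contain infs/NaNs
    scaler.step(optimizer)
    # update the scale for the next iteration
    scaler.update()

The ordering is the important part: scaler.unscale_(optimizer) must run before clipping, otherwise the clip threshold would be compared against loss-scaled gradients and --max_norm would not mean what it says; scaler.step(optimizer) then skips the update whenever the unscaled gradients contain infs or NaNs.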
