
Commit 3fe1ec6

[torchtrain] add gradient scaler
ghstack-source-id: 3d250d1
Pull Request resolved: #25

1 parent 2c7e48b

File tree: 1 file changed (+12, -4 lines)

train.py

Lines changed: 12 additions & 4 deletions
@@ -6,6 +6,7 @@
 # torch imports
 import torch
 import torch.nn.functional as F
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.utils.data import DataLoader
 
 from torchtrain.profiling import maybe_run_profiler
@@ -82,6 +83,8 @@ def main(args):
 
     train_state = TrainState()
 
+    scaler = ShardedGradScaler()
+
     # train loop
     model.train()
 
@@ -101,14 +104,19 @@
             )
         loss = tok_loss.mean()
 
-        # backward
-        loss.backward()
-        # TODO: add grad scaler
+        # backward on scaled loss to create scaled gradients
+        scaler.scale(loss).backward()
 
         # optimizer step
-        optimizer.step()
+        # scaler.step() first unscales gradients of the optimizer's params.
+        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+        # otherwise, optimizer.step() is skipped.
+        scaler.step(optimizer)
         optimizer.zero_grad()
 
+        # updates the scale for next iteration
+        scaler.update()
+
         # if profiler is active
         if torch_profiler:
             torch_profiler.step()
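
For context, the change follows the standard torch.cuda.amp gradient-scaler loop (scale -> backward -> step -> update), with ShardedGradScaler standing in for GradScaler because FSDP shards gradients across ranks. The sketch below is not part of this commit; model, optimizer, and data_loader are hypothetical placeholders, and the FSDP wrapping and loss computation are assumptions made only for illustration.

# Minimal sketch of the gradient-scaler pattern used in the diff above.
# `model`, `optimizer`, and `data_loader` are placeholders, not names from
# torchtrain's train.py.
import torch
import torch.nn.functional as F
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler


def train(model: FSDP, optimizer, data_loader):
    # ShardedGradScaler plays the role of torch.cuda.amp.GradScaler for
    # FSDP-wrapped models, where gradients are sharded across ranks.
    scaler = ShardedGradScaler()
    model.train()

    for input_ids, labels in data_loader:
        pred = model(input_ids)
        loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))

        # Multiply the loss by the current scale before backward so the
        # fp16 gradients stay in a range that avoids underflow.
        scaler.scale(loss).backward()

        # Unscale the gradients and call optimizer.step() only if they
        # contain no infs/NaNs; otherwise the step is skipped.
        scaler.step(optimizer)
        optimizer.zero_grad()

        # Adjust the scale factor for the next iteration (it is reduced
        # after a skipped step and gradually increased otherwise).
        scaler.update()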
