1 file changed: 12 additions, 4 deletions
# torch imports
import torch
import torch.nn.functional as F
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.utils.data import DataLoader

from torchtrain.profiling import maybe_run_profiler
@@ -82,6 +83,8 @@ def main(args):

    train_state = TrainState()

+    scaler = ShardedGradScaler()
+
    # train loop
    model.train()

@@ -101,14 +104,19 @@ def main(args):
        )
        loss = tok_loss.mean()

-        # backward
-        loss.backward()
-        # TODO: add grad scaler
+        # backward on scaled loss to create scaled gradients
+        scaler.scale(loss).backward()

        # optimizer step
-        optimizer.step()
+        # scaler.step() first unscales gradients of the optimizer's params.
+        # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+        # otherwise, optimizer.step() is skipped.
+        scaler.step(optimizer)
        optimizer.zero_grad()

+        # updates the scale for next iteration
+        scaler.update()
+
        # if profiler is active
        if torch_profiler:
            torch_profiler.step()
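For context, the new calls follow the standard PyTorch gradient-scaler pattern: scale the loss before backward, step the optimizer through the scaler (which unscales the gradients and skips the step if they contain infs/NaNs), then update the scale factor. Below is a minimal, self-contained sketch of that pattern, not the PR's code: the model, optimizer, and data are hypothetical placeholders, and it uses torch.cuda.amp.GradScaler rather than ShardedGradScaler, since the latter additionally requires torch.distributed initialization and an FSDP-wrapped model; the scale/step/update API is the same.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-ins for the model, optimizer, and data loader used in the PR.
model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(10)]

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
model.to(device)

# For a single (non-sharded) model, torch.cuda.amp.GradScaler exposes the same
# scale/step/update API that ShardedGradScaler provides for FSDP-sharded params.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

model.train()
for inputs, targets in data:
    inputs, targets = inputs.to(device), targets.to(device)

    # forward in mixed precision (fp16 autocast only when CUDA is available)
    with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_cuda):
        logits = model(inputs)
        loss = F.cross_entropy(logits, targets)

    # backward on the scaled loss to create scaled gradients
    scaler.scale(loss).backward()

    # unscale gradients; optimizer.step() runs only if they are finite
    scaler.step(optimizer)
    optimizer.zero_grad()

    # adjust the scale factor for the next iteration
    scaler.update()
```

With enabled=torch.cuda.is_available(), both autocast and the scaler become no-ops on CPU, so the sketch runs anywhere while preserving the exact call sequence added in the diff.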