Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# torch imports
import torch
import torch.nn.functional as F
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader

from torchtrain.profiling import maybe_run_profiler
Expand Down Expand Up @@ -71,6 +73,15 @@ def main(args):
# TODO: add scheduler if needed
optimizer = build_optimizer(model, args)

# apply gradient scaling if mixed precision training is enabled with fp16 param dtype
if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
enable_grad_scaling = True
rank0_log(f"Enabling gradient scaling for mixed precision training.")
else:
enable_grad_scaling = False
rank0_log("Gradient scaling not enabled.")
scaler = ShardedGradScaler(enabled=enable_grad_scaling)

# TODO: add metrics

# torch.compile model for improved performance
Expand Down Expand Up @@ -101,14 +112,19 @@ def main(args):
)
loss = tok_loss.mean()

# backward
loss.backward()
# TODO: add grad scaler
# backward on scaled loss to create scaled gradients
scaler.scale(loss).backward()

# optimizer step
optimizer.step()
# scaler.step() first unscales gradients of the optimizer's params.
# If gradients don't contain infs/NaNs, optimizer.step() is then called,
# otherwise, optimizer.step() is skipped.
scaler.step(optimizer)
optimizer.zero_grad()

# updates the scale for next iteration
scaler.update()

# if profiler is active
if torch_profiler:
torch_profiler.step()
Expand Down