diff --git a/train.py b/train.py
index 073e5b1755..447fc7436d 100644
--- a/train.py
+++ b/train.py
@@ -6,6 +6,8 @@
 # torch imports
 import torch
 import torch.nn.functional as F
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader
 
 from torchtrain.profiling import maybe_run_profiler
@@ -40,6 +42,18 @@ def build_optimizer(model, args):
     return optimizer
 
 
+def build_grad_scaler(model):
+    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
+    if model.mixed_precision.param_dtype == torch.float16:
+        enable_grad_scaling = True
+        rank0_log("Enabling gradient scaling for mixed precision training.")
+    else:
+        enable_grad_scaling = False
+        rank0_log("Gradient scaling not enabled.")
+
+    return ShardedGradScaler(enabled=enable_grad_scaling)
+
+
 def main(args):
     init_logger()
 
@@ -67,10 +81,15 @@ def main(args):
     # apply PTD parallelisms + AC
     model = models_parallelize_fns[model_name](model, args)
 
+    # to use FSDP-customized gradient scaler and gradient clipping solutions
+    assert isinstance(model, FSDP)
+
     # build optimizer after applying parallelisms to the model
     # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
 
+    scaler = build_grad_scaler(model)
+
     # TODO: add metrics
 
     # torch.compile model for improved performance
@@ -101,14 +120,19 @@ def main(args):
             )
             loss = tok_loss.mean()
 
-            # backward
-            loss.backward()
-            # TODO: add grad scaler
+            # backward on scaled loss to create scaled gradients
+            scaler.scale(loss).backward()
 
             # optimizer step
-            optimizer.step()
+            # scaler.step() first unscales gradients of the optimizer's params.
+            # If gradients don't contain infs/NaNs, optimizer.step() is then called;
+            # otherwise, optimizer.step() is skipped.
+            scaler.step(optimizer)
             optimizer.zero_grad()
 
+            # update the scale for the next iteration
+            scaler.update()
+
             # if profiler is active
             if torch_profiler:
                 torch_profiler.step()
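
For context, the scale -> step -> update contract introduced above can be exercised outside FSDP as well. The sketch below is a minimal single-device illustration of that same pattern using torch.cuda.amp.GradScaler, which ShardedGradScaler mirrors API-wise; the toy model, data, and hyperparameters are invented for illustration and are not part of this PR.

```python
import torch
import torch.nn as nn

# Toy setup (hypothetical model/data, not part of the PR).
device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"
model = nn.Linear(16, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# A disabled scaler is a pass-through, so this sketch also runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

for _ in range(3):
    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)

    # Run the forward/loss in fp16 only when CUDA is available.
    with torch.autocast(
        device_type=device,
        dtype=torch.float16 if use_fp16 else torch.bfloat16,
        enabled=use_fp16,
    ):
        loss = nn.functional.mse_loss(model(x), y)

    # Same contract as the training loop in the diff:
    # scale the loss before backward, let scaler.step() unscale and
    # possibly skip the optimizer step on inf/NaN grads, then update
    # the scale factor for the next iteration.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```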