From 2f03afc5cea370ba29c9c3e58016bed9c82ae0d2 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 30 Jan 2024 15:17:42 -0800
Subject: [PATCH 1/3] [torchtrain] add gradient scaler

[ghstack-poisoned]
---
 train.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 073e5b1755..08709ca2b4 100644
--- a/train.py
+++ b/train.py
@@ -6,6 +6,7 @@
 # torch imports
 import torch
 import torch.nn.functional as F
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.utils.data import DataLoader
 
 from torchtrain.profiling import maybe_run_profiler
@@ -82,6 +83,8 @@ def main(args):
 
     train_state = TrainState()
 
+    scaler = ShardedGradScaler()
+
     # train loop
     model.train()
 
@@ -101,14 +104,19 @@ def main(args):
             )
             loss = tok_loss.mean()
 
-            # backward
-            loss.backward()
-            # TODO: add grad scaler
+            # backward on scaled loss to create scaled gradients
+            scaler.scale(loss).backward()
 
             # optimizer step
-            optimizer.step()
+            # scaler.step() first unscales gradients of the optimizer's params.
+            # If gradients don't contain infs/NaNs, optimizer.step() is then called,
+            # otherwise, optimizer.step() is skipped.
+            scaler.step(optimizer)
             optimizer.zero_grad()
 
+            # updates the scale for next iteration
+            scaler.update()
+
             # if profiler is active
             if torch_profiler:
                 torch_profiler.step()

From 255085f9bd1e5b6d7fe6c0aa68fb8c4e661d945f Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Tue, 30 Jan 2024 18:33:48 -0800
Subject: [PATCH 2/3] Update on "[torchtrain] add gradient scaler"

[ghstack-poisoned]
---
 train.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index 08709ca2b4..8a3a096163 100644
--- a/train.py
+++ b/train.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader
 
 from torchtrain.profiling import maybe_run_profiler
@@ -72,6 +73,15 @@ def main(args):
     # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
 
+    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
+    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
+        enable_grad_scaling = True
+        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+    else:
+        enable_grad_scaling = False
+        rank0_log("Gradient scaling not enabled.")
+    scaler = ShardedGradScaler(enabled=enable_grad_scaling)
+
     # TODO: add metrics
 
     # torch.compile model for improved performance
@@ -83,8 +93,6 @@ def main(args):
 
     train_state = TrainState()
 
-    scaler = ShardedGradScaler()
-
     # train loop
     model.train()
 

From 059113957e89b208e3a5f542cc958d4b48da6f4a Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 31 Jan 2024 13:20:32 -0800
Subject: [PATCH 3/3] Update on "[torchtrain] add gradient scaler"

As titled. Only enable the gradient scaler if mixed precision training is used with parameter dtype being fp16.

[ghstack-poisoned]
---
 train.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/train.py b/train.py
index 8a3a096163..447fc7436d 100644
--- a/train.py
+++ b/train.py
@@ -42,6 +42,18 @@ def build_optimizer(model, args):
     return optimizer
 
 
+def build_grad_scaler(model):
+    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
+    if model.mixed_precision.param_dtype == torch.float16:
+        enable_grad_scaling = True
+        rank0_log(f"Enabling gradient scaling for mixed precision training.")
+    else:
+        enable_grad_scaling = False
+        rank0_log("Gradient scaling not enabled.")
+
+    return ShardedGradScaler(enabled=enable_grad_scaling)
+
+
 def main(args):
     init_logger()
 
@@ -69,18 +81,14 @@ def main(args):
     # apply PTD parallelisms + AC
     model = models_parallelize_fns[model_name](model, args)
 
+    # to use FSDP-customized gradient scaler and gradient clipping solutions
+    assert isinstance(model, FSDP)
+
     # build optimizer after apply parallelisms to the model
     # TODO: add scheduler if needed
     optimizer = build_optimizer(model, args)
 
-    # apply gradient scaling if mixed precision training is enabled with fp16 param dtype
-    if isinstance(model, FSDP) and model.mixed_precision.param_dtype == torch.float16:
-        enable_grad_scaling = True
-        rank0_log(f"Enabling gradient scaling for mixed precision training.")
-    else:
-        enable_grad_scaling = False
-        rank0_log("Gradient scaling not enabled.")
-    scaler = ShardedGradScaler(enabled=enable_grad_scaling)
+    scaler = build_grad_scaler(model)
 
     # TODO: add metrics