33 changes: 33 additions & 0 deletions torchtrain/lr_scheduling.py
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

from torchtrain.logging_utils import rank0_log


class LinearScheduler:
def __init__(self, args):
self.lr_max = args.lr
self.lr_min = args.lr / 10
self.lr_warmup_pct = 0.10
# enforce min of 2 steps for warmup
self.warmup_steps = max(int(args.steps * self.lr_warmup_pct), 2)

rank0_log(
f"LR Warmup Schedule: {self.lr_min} -> {self.lr_max} with {self.warmup_steps} warmup steps"
)
self.decay_steps = args.steps - self.warmup_steps
self.curr_lr = 0

def set_lr(self, optimizer, step):
"""Set the learning rate for the optimizer"""
if step < self.warmup_steps:
self.curr_lr = self.lr_max * (step / self.warmup_steps)
else:
self.curr_lr = self.lr_min + (
(self.lr_max - self.lr_min)
* (1 - (step - self.warmup_steps) / self.decay_steps)
)
# apply across all optim groups
for param_group in optimizer.param_groups:
param_group["lr"] = self.curr_lr
rank0_log(f"Optimizer LR Update: {step=}, lr = {round(self.curr_lr,6)}")
6 changes: 5 additions & 1 deletion train.py
@@ -69,6 +69,9 @@ def main(args):

# build optimizer after applying parallelisms to the model
# TODO: add scheduler if needed
from torchtrain.lr_scheduling import LinearScheduler
scheduler = LinearScheduler(args)

optimizer = build_optimizer(model, args)

# TODO: add metrics
@@ -88,6 +91,7 @@ def main(args):
with maybe_run_profiler() as torch_profiler:
while train_state.step < args.steps or args.steps == -1:
train_state.step += 1
scheduler.set_lr(optimizer, train_state.step)
# get batch
batch = next(iter(data_loader))
input_ids, labels = batch
@@ -143,7 +147,7 @@ def main(args):
parser.add_argument(
"--optimizer", type=str, default="AdamW", help="optimizer to use"
)
parser.add_argument("--lr", type=float, default=2e-5, help="learning rate to use")
parser.add_argument("--lr", type=float, default=8e-4, help="learning rate to use")
parser.add_argument(
"--steps", type=int, default=-1, help="how many train steps to run"
)
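
For context on the wiring: set_lr runs at the top of the training loop, before the optimizer step, so each step's update uses the freshly scheduled LR. A minimal end-to-end sketch (the stub model and synthetic loss are mine, not from this repo; it assumes torchtrain is importable and that rank0_log degrades gracefully outside a distributed run; --steps must be positive here, since the default -1 would leave decay_steps negative):

```python
import argparse

import torch

from torchtrain.lr_scheduling import LinearScheduler

# Hypothetical args standing in for the parser output; names match the flags above.
args = argparse.Namespace(lr=8e-4, steps=100)

model = torch.nn.Linear(8, 8)  # stub model, stand-in for the real one
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
scheduler = LinearScheduler(args)

for step in range(1, args.steps + 1):
    scheduler.set_lr(optimizer, step)  # set the LR before this step's update
    loss = model(torch.randn(4, 8)).pow(2).mean()  # synthetic forward/loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```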