22 changes: 21 additions & 1 deletion open_instruct/finetune.py
@@ -373,6 +373,13 @@ class FlatArguments:
             "help": "Whether to clean up all previous checkpoints at the end of the run.",
         },
     )
+    final_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
+            " Only for linear schedulers, currently."
+        },
+    )
 
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -387,6 +394,11 @@ def __post_init__(self):
             raise ValueError("Cannot provide two dataset selection mechanisms.")
         if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
+        if self.final_lr_ratio is not None:
+            if self.lr_scheduler_type != "linear":
+                raise NotImplementedError("final_lr_ratio is currently only implemented for linear schedulers")
+            if not (1.0 >= self.final_lr_ratio >= 0.0):
+                raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")
 
 
 def main(args: FlatArguments, tc: TokenizerConfig):
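
The new checks are easy to exercise in isolation. The following is an editorial sketch, not code from the PR: the real FlatArguments has many more fields, so a pared-down stand-in dataclass (hypothetical name MiniArgs) reproduces just the two validations.

from dataclasses import dataclass
from typing import Optional

# Hypothetical stand-in for FlatArguments, keeping only the fields the
# new __post_init__ checks touch.
@dataclass
class MiniArgs:
    lr_scheduler_type: str = "linear"
    final_lr_ratio: Optional[float] = None

    def __post_init__(self):
        if self.final_lr_ratio is not None:
            if self.lr_scheduler_type != "linear":
                raise NotImplementedError("final_lr_ratio is currently only implemented for linear schedulers")
            if not (1.0 >= self.final_lr_ratio >= 0.0):
                raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")

MiniArgs(final_lr_ratio=0.1)                                # ok
# MiniArgs(lr_scheduler_type="cosine", final_lr_ratio=0.1)  # NotImplementedError
# MiniArgs(final_lr_ratio=1.5)                              # ValueError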
@@ -708,11 +720,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     num_training_steps_for_scheduler = (
         args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
     )
+
+    num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
+    if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
+        # Correct num_training_steps_for_scheduler to respect final_lr_ratio for a linear scheduler
+        num_training_steps_for_scheduler = (
+            num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
+        ) / (1 - args.final_lr_ratio)

Collaborator (on the `num_training_steps_for_scheduler = (` line):
num_training_steps_for_scheduler is not used anywhere else except get_scheduler, right?

Owner Author:
Correct, it's just defined here, and maybe updated if the user specifies final_lr_ratio and is using a linear scheduler.
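
To see why this expression is right (an editorial derivation, not part of the PR): after warmup, the Hugging Face linear schedule scales the base LR by (T' - t) / (T' - W) at step t, where T' is the num_training_steps given to the scheduler and W is num_warmup_steps. Requiring that factor to equal r = final_lr_ratio at the real last step T and solving (T' - T) / (T' - W) = r gives T' = (T - r*W) / (1 - r), exactly the correction above. Two edge notes: r = 1 passes the range check but divides by zero here, and W is deliberately computed from the uncorrected step count (hence the one-line deletion below), so warmup is not stretched along with the horizon. A minimal runnable check, with assumed hyperparameter values:

import torch
from transformers import get_scheduler

# Editorial sketch: check that a scheduler horizon of T' = (T - r*W) / (1 - r)
# makes the linear schedule end at r * base_lr after T real optimizer steps.
# All concrete values below are assumed for illustration.
base_lr, r = 2e-5, 0.1           # learning_rate and final_lr_ratio
T = 1000                         # real number of training steps
W = int(T * 0.03)                # warmup steps from a 3% warmup_ratio
T_prime = (T - r * W) / (1 - r)  # corrected scheduler horizon

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=W, num_training_steps=T_prime
)
for _ in range(T):
    optimizer.step()
    lr_scheduler.step()
print(optimizer.param_groups[0]["lr"] / base_lr)  # ~0.1, i.e. final_lr_ratio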

     lr_scheduler = get_scheduler(
         name=args.lr_scheduler_type,
         optimizer=optimizer,
         num_training_steps=num_training_steps_for_scheduler,
-        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
+        num_warmup_steps=num_warmup_steps,
     )
     # Prepare everything with `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(