diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 6a998900e3..587754a654 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -373,6 +373,13 @@ class FlatArguments:
             "help": "Whether to clean up all previous checkpoints at the end of the run.",
         },
     )
+    final_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
+            " Only for linear schedulers, currently."
+        },
+    )
 
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -387,6 +394,11 @@ def __post_init__(self):
             raise ValueError("Cannot provide two dataset selection mechanisms.")
         if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
+        if self.final_lr_ratio is not None:
+            if self.lr_scheduler_type != "linear":
+                raise NotImplementedError("final_lr_ratio only currently implemented for linear schedulers")
+            if not (1.0 >= self.final_lr_ratio >= 0.0):
+                raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")
 
 
 def main(args: FlatArguments, tc: TokenizerConfig):
@@ -708,11 +720,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     num_training_steps_for_scheduler = (
         args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
     )
+
+    num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
+    if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
+        # Correct num_training_steps_for_scheduler to respect final_lr_ratio for a linear scheduler
+        num_training_steps_for_scheduler = (
+            num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
+        ) / (1 - args.final_lr_ratio)
+
     lr_scheduler = get_scheduler(
         name=args.lr_scheduler_type,
         optimizer=optimizer,
         num_training_steps=num_training_steps_for_scheduler,
-        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
+        num_warmup_steps=num_warmup_steps,
     )
     # Prepare everything with `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
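
Why the adjusted horizon works (not part of the diff): Hugging Face's linear scheduler decays the LR after warmup by the factor (T - step) / (T - W), where T is num_training_steps and W is num_warmup_steps. Solving (T - S) / (T - W) = final_lr_ratio for T, with S the real number of training steps, gives T = (S - final_lr_ratio * W) / (1 - final_lr_ratio), which is exactly the formula in the last hunk. The sketch below checks this numerically; the step counts, ratio, and optimizer are hypothetical and only illustrate the behavior.

# Illustrative check (hypothetical numbers): stretching num_training_steps to
# (S - r * W) / (1 - r) makes the linear scheduler's LR at the real final step S
# land at r * base_lr instead of decaying to 0.
import torch
from transformers import get_scheduler

base_lr, r = 2e-5, 0.1                 # learning_rate and an assumed final_lr_ratio
real_steps = 1_000                     # S: actual optimizer steps taken during training
warmup_steps = int(real_steps * 0.03)  # W: warmup computed from the original horizon

# Adjusted horizon handed to the scheduler (same formula as in the diff);
# it stays a float, which the linear schedule's lambda handles fine.
scheduler_steps = (real_steps - r * warmup_steps) / (1 - r)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=scheduler_steps,
)
for _ in range(real_steps):
    lr_scheduler.step()

print(lr_scheduler.get_last_lr()[0] / base_lr)  # ~0.1, i.e. final_lr_ratio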