diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index a87a9cbce0..dfb7aa1966 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -362,6 +362,14 @@ class FlatArguments:
     oe_eval_max_length: int = 4096
     """the max generation length for evaluation for oe-eval"""
 
+    final_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
+            " Only for linear schedulers, currently."
+        },
+    )
+
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
             raise ValueError("reduce_loss must be either 'mean' or 'sum'")
@@ -375,6 +383,11 @@ def __post_init__(self):
             raise ValueError("Cannot provide two dataset selection mechanisms.")
         if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
            raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
+        if self.final_lr_ratio is not None:
+            if self.lr_scheduler_type != "linear":
+                raise NotImplementedError("final_lr_ratio only currently implemented for linear schedulers")
+            if not (1.0 >= self.final_lr_ratio >= 0.0):
+                raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")
 
 
 def main(args: FlatArguments, tc: TokenizerConfig):
@@ -688,11 +701,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     num_training_steps_for_scheduler = (
         args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
     )
+
+    num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
+    if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
+        # Correct num_training_steps_for_scheduler to respect final_lr_ratio for a linear scheduler
+        num_training_steps_for_scheduler = (
+            num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
+        ) / (1 - args.final_lr_ratio)
+
     lr_scheduler = get_scheduler(
         name=args.lr_scheduler_type,
         optimizer=optimizer,
         num_training_steps=num_training_steps_for_scheduler,
-        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
+        num_warmup_steps=num_warmup_steps,
     )
     # Prepare everything with `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
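
For context on the stretching arithmetic in the last hunk: an HF-style linear schedule warms up to `learning_rate` over `num_warmup_steps` W and then decays linearly to 0 at `num_training_steps`. To make the LR land at `final_lr_ratio * learning_rate` on the real last step S, the scheduler horizon is stretched to `(S - final_lr_ratio * W) / (1 - final_lr_ratio)`. Below is a minimal, standalone sketch (not part of the PR; the helper names are illustrative) that checks this algebra against the standard `get_linear_schedule_with_warmup` decay formula:

```python
def corrected_total_steps(actual_total: int, warmup: int, final_lr_ratio: float) -> float:
    # Hypothetical helper mirroring the correction in the diff: the horizon handed
    # to the scheduler so the decay reaches final_lr_ratio * lr at the real last step.
    return (actual_total - final_lr_ratio * warmup) / (1 - final_lr_ratio)


def linear_lr(step: int, lr: float, warmup: int, scheduler_total: float) -> float:
    # HF-style linear schedule: ramp from 0 to lr over `warmup` steps, then decay
    # linearly toward 0 at `scheduler_total` (as in get_linear_schedule_with_warmup).
    if step < warmup:
        return lr * step / max(1, warmup)
    return lr * max(0.0, (scheduler_total - step) / max(1.0, scheduler_total - warmup))


if __name__ == "__main__":
    lr, actual_total, warmup, ratio = 2e-5, 1000, 30, 0.1
    stretched = corrected_total_steps(actual_total, warmup, ratio)
    final_lr = linear_lr(actual_total, lr, warmup, stretched)
    # At the real final step the LR should equal final_lr_ratio * learning_rate.
    assert abs(final_lr - ratio * lr) < 1e-12
    print(f"LR at last real step: {final_lr:.2e} (target {ratio * lr:.2e})")
```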