23 changes: 22 additions & 1 deletion open_instruct/finetune.py
@@ -362,6 +362,14 @@ class FlatArguments:
     oe_eval_max_length: int = 4096
     """the max generation length for evaluation for oe-eval"""
 
+    final_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
+            " Only for linear schedulers, currently."
+        },
+    )
+
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
             raise ValueError("reduce_loss must be either 'mean' or 'sum'")
@@ -375,6 +383,11 @@ def __post_init__(self):
raise ValueError("Cannot provide two dataset selection mechanisms.")
if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
if self.final_lr_ratio is not None:
if self.lr_scheduler_type != "linear":
raise NotImplementedError("final_lr_ratio only currently implemented for linear schedulers")
if not (1.0 >= self.final_lr_ratio >= 0.0):
raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")


def main(args: FlatArguments, tc: TokenizerConfig):
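For reference, a minimal self-contained sketch of how this validation behaves; the dataclass below is a hypothetical stand-in for the relevant FlatArguments fields, not the real class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SchedulerArgs:
    """Hypothetical stand-in for the relevant FlatArguments fields."""

    lr_scheduler_type: str = "linear"
    final_lr_ratio: Optional[float] = None

    def __post_init__(self):
        if self.final_lr_ratio is not None:
            if self.lr_scheduler_type != "linear":
                raise NotImplementedError("final_lr_ratio is currently only implemented for linear schedulers")
            if not (1.0 >= self.final_lr_ratio >= 0.0):
                raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")


SchedulerArgs(final_lr_ratio=0.1)  # ok
try:
    SchedulerArgs(lr_scheduler_type="cosine", final_lr_ratio=0.1)
except NotImplementedError as e:
    print(e)  # final_lr_ratio is currently only implemented for linear schedulers
```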
@@ -688,11 +701,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     num_training_steps_for_scheduler = (
         args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
     )
+
+    num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
+    if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
+        # Stretch the decay horizon so the linear schedule reaches
+        # final_lr_ratio * learning_rate (instead of 0) at the true last step:
+        # solve (T' - T) / (T' - w) = final_lr_ratio for T', where T is the real
+        # number of training steps and w the number of warmup steps.
+        num_training_steps_for_scheduler = (
+            num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
+        ) / (1 - args.final_lr_ratio)
+
     lr_scheduler = get_scheduler(
         name=args.lr_scheduler_type,
         optimizer=optimizer,
         num_training_steps=num_training_steps_for_scheduler,
-        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
+        num_warmup_steps=num_warmup_steps,
     )
     # Prepare everything with `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
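To check the horizon adjustment, here is a standalone sketch (pure Python, with hypothetical values), assuming transformers' linear schedule: warm up to the peak LR, then decay linearly to 0 at num_training_steps.

```python
# Standalone check of the horizon stretch above.
def linear_lr(step: int, peak_lr: float, warmup: int, horizon: float) -> float:
    """LR under linear warmup followed by linear decay to 0 at `horizon`."""
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    return peak_lr * max(0.0, (horizon - step) / (horizon - warmup))


peak_lr, total_steps, warmup_ratio, final_lr_ratio = 2e-5, 1000, 0.03, 0.1
warmup = int(total_steps * warmup_ratio)  # 30
# Same algebra as the diff: solve (T' - T) / (T' - w) = final_lr_ratio for T'.
horizon = (total_steps - final_lr_ratio * warmup) / (1 - final_lr_ratio)

# The LR at the real last step lands at final_lr_ratio * peak_lr:
print(linear_lr(total_steps, peak_lr, warmup, horizon) / peak_lr)  # ~0.1
```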