From 7a0671acf05d42defcd0bfd1a82e0576f4ef64d3 Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Fri, 27 Jun 2025 16:52:25 -0400
Subject: [PATCH 1/6] final_lr_ratio

---
 open_instruct/finetune.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 6a998900e3..587754a654 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -373,6 +373,13 @@ class FlatArguments:
             "help": "Whether to clean up all previous checkpoints at the end of the run.",
         },
     )
+    final_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
+            " Only for linear schedulers, currently."
+        },
+    )
 
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -387,6 +394,11 @@ def __post_init__(self):
             raise ValueError("Cannot provide two dataset selection mechanisms.")
         if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
+        if self.final_lr_ratio is not None:
+            if self.lr_scheduler_type != "linear":
+                raise NotImplementedError("final_lr_ratio is currently only implemented for linear schedulers")
+            if not (1.0 > self.final_lr_ratio >= 0.0):
+                raise ValueError(f"final_lr_ratio must be in [0, 1), not {self.final_lr_ratio=}")
 
 
 def main(args: FlatArguments, tc: TokenizerConfig):
@@ -708,11 +720,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     num_training_steps_for_scheduler = (
         args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
     )
+
+    num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
+    if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
+        # Stretch num_training_steps_for_scheduler so a linear schedule ends at final_lr_ratio * learning_rate.
+        num_training_steps_for_scheduler = (
+            num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
+        ) / (1 - args.final_lr_ratio)
+
     lr_scheduler = get_scheduler(
         name=args.lr_scheduler_type,
         optimizer=optimizer,
         num_training_steps=num_training_steps_for_scheduler,
-        num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
+        num_warmup_steps=num_warmup_steps,
    )
     # Prepare everything with `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
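A note on the arithmetic in this patch: for a linear schedule, the post-warmup LR follows lr(t) = peak * (N - t) / (N - W), where W is the warmup step count and N is the step count handed to the scheduler. Requiring lr(T) = final_lr_ratio * peak at the true training horizon T and solving for N gives N = (T - final_lr_ratio * W) / (1 - final_lr_ratio), which is exactly the correction applied before get_scheduler; final_lr_ratio = 0 recovers N = T and the usual decay to zero. A standalone sketch (illustrative numbers, not part of the patch) that checks the algebra, with linear_lr mirroring the shape of transformers' linear schedule:

def linear_lr(step, peak, warmup, scheduler_steps):
    # Linear warmup to `peak`, then linear decay, as in transformers' linear schedule.
    if step < warmup:
        return peak * step / max(1, warmup)
    return peak * (scheduler_steps - step) / max(1.0, scheduler_steps - warmup)

peak_lr = 2e-5
actual_steps = 1000   # steps actually trained (hypothetical)
warmup = 30           # int(actual_steps * warmup_ratio)
final_lr_ratio = 0.1

# The patch's correction: solve (N - T) / (N - W) = final_lr_ratio for N.
scheduler_steps = (actual_steps - final_lr_ratio * warmup) / (1 - final_lr_ratio)

final_lr = linear_lr(actual_steps, peak_lr, warmup, scheduler_steps)
assert abs(final_lr - final_lr_ratio * peak_lr) < 1e-12
print(f"lr at step {actual_steps}: {final_lr:.2e} (target {final_lr_ratio * peak_lr:.2e})")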
From 5b827c067ccf1cdf799ba81cc5a41679b37c9ff1 Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Fri, 27 Jun 2025 16:52:57 -0400
Subject: [PATCH 2/6] rm run_name, add_seed_and_date_to_exp_name

prev-branch: padding-free-squashing-3
---
 open_instruct/finetune.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 587754a654..b0e398d7b9 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -87,8 +87,6 @@ class FlatArguments:
 
     exp_name: str = os.path.basename(__file__)[: -len(".py")]
     """The name of this experiment"""
-    run_name: Optional[str] = None
-    """A unique name of this run"""
     do_not_randomize_output_dir: bool = False
     """By default the output directory will be randomized"""
     model_name_or_path: Optional[str] = field(
@@ -380,6 +378,7 @@ class FlatArguments:
             " Only for linear schedulers, currently."
         },
     )
+    add_seed_and_date_to_exp_name: bool = True
 
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -437,9 +436,11 @@ def main(args: FlatArguments, tc: TokenizerConfig):
 
     # ------------------------------------------------------------
     # Set up runtime variables
-    args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
+
+    if args.add_seed_and_date_to_exp_name:
+        args.exp_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
     if not args.do_not_randomize_output_dir:
-        args.output_dir = os.path.join(args.output_dir, args.run_name)
+        args.output_dir = os.path.join(args.output_dir, args.exp_name)
     logger.info("using the output directory: %s", args.output_dir)
     args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
     if is_beaker_job():
@@ -453,7 +456,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
         args.hf_entity = HfApi().whoami()["name"]
         args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
     if args.hf_repo_revision is None:
-        args.hf_repo_revision = args.run_name
+        args.hf_repo_revision = args.exp_name
     args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
     if is_beaker_job():
         beaker_config = maybe_get_beaker_config()
@@ -477,7 +480,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             experiment_config,
             init_kwargs={
                 "wandb": {
-                    "name": args.run_name,
+                    "name": args.exp_name,
                     "entity": args.wandb_entity,
                     "tags": [args.exp_name] + get_wandb_tags(),
                 }

From fd7d49d805485c1a5af7c067de37b117f219904d Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Fri, 27 Jun 2025 21:49:50 -0400
Subject: [PATCH 3/6] additional_model_arguments

prev-branch: padding-free-squashing-4
---
 open_instruct/finetune.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index b0e398d7b9..af082161be 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -379,6 +379,10 @@ class FlatArguments:
         },
     )
     add_seed_and_date_to_exp_name: bool = True
+    additional_model_arguments: Optional[list[str]] = field(
+        default=None,
+        metadata={"help": "A space-delimited list of key:val pairs to be passed as additional model config arguments."},
+    )
 
     def __post_init__(self):
         if self.reduce_loss not in ["mean", "sum"]:
@@ -399,6 +403,24 @@ def __post_init__(self):
             if not (1.0 > self.final_lr_ratio >= 0.0):
                 raise ValueError(f"final_lr_ratio must be in [0, 1), not {self.final_lr_ratio=}")
 
+        if self.additional_model_arguments is not None:
+
+            def maybe_convert_(x):
+                # Try casting to int first, then float; otherwise keep the raw string.
+                for cast in (int, float):
+                    try:
+                        return cast(x)
+                    except ValueError:
+                        continue
+                return x
+
+            try:
+                split_args = [x.split(":", maxsplit=1) for x in self.additional_model_arguments]
+                self.additional_model_arguments = {k: maybe_convert_(v) for k, v in split_args}
+            except ValueError:
+                raise ValueError("Malformed additional model arguments. Should be a space-delimited list of key:val.")
+        else:
+            self.additional_model_arguments = {}
+
 
 def main(args: FlatArguments, tc: TokenizerConfig):
     # ------------------------------------------------------------
@@ -549,12 +571,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             args.config_name,
             revision=args.model_revision,
             trust_remote_code=tc.trust_remote_code,
+            **args.additional_model_arguments,
         )
     elif args.model_name_or_path:
         config = AutoConfig.from_pretrained(
             args.model_name_or_path,
             revision=args.model_revision,
             trust_remote_code=tc.trust_remote_code,
+            **args.additional_model_arguments,
         )
     else:
         raise ValueError(
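To see how additional_model_arguments flows from the command line into AutoConfig.from_pretrained, here is a standalone sketch of the parsing done in __post_init__ above (the flag values are hypothetical, chosen only for illustration):

# Hypothetical CLI:
#   --additional_model_arguments rope_theta:1000000.0 num_hidden_layers:4 attn_implementation:flash_attention_2
raw = ["rope_theta:1000000.0", "num_hidden_layers:4", "attn_implementation:flash_attention_2"]

def maybe_convert_(x):
    # Try casting to int first, then float; otherwise keep the raw string.
    for cast in (int, float):
        try:
            return cast(x)
        except ValueError:
            continue
    return x

parsed = {k: maybe_convert_(v) for k, v in (x.split(":", maxsplit=1) for x in raw)}
print(parsed)
# -> {'rope_theta': 1000000.0, 'num_hidden_layers': 4, 'attn_implementation': 'flash_attention_2'}
# This dict is later splatted into AutoConfig.from_pretrained(..., **parsed).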
From f5e19fafa0936326987ce758e5f1389442ff5633 Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Wed, 25 Jun 2025 10:28:14 -0400
Subject: [PATCH 4/6] sync_each_batch=True grad acc

prev-branch: padding-free-squashing-5
---
 open_instruct/finetune.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index af082161be..d212d80bdf 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -40,6 +40,7 @@
 import torch
 import transformers
 from accelerate import Accelerator, DataLoaderConfiguration
+from accelerate.accelerator import GradientAccumulationPlugin
 from accelerate.logging import get_logger
 from accelerate.utils import InitProcessGroupKwargs, set_seed
 from huggingface_hub import HfApi
@@ -434,11 +435,15 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     # if you get timeouts (e.g. due to long tokenization) increase this.
     timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
     dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
+
     accelerator = Accelerator(
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         dataloader_config=dataloader_config,
         **accelerator_log_kwargs,
         kwargs_handlers=[timeout_kwargs],
+        gradient_accumulation_plugin=GradientAccumulationPlugin(
+            num_steps=args.gradient_accumulation_steps,
+            sync_each_batch=True,
+        ),
     )

From 96c734b44a82ce7a602ff9a6aa45956f6fb29d30 Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Tue, 24 Jun 2025 15:25:09 -0400
Subject: [PATCH 5/6] no grad acc averaging for sum losses

prev-branch: padding-free-squashing-6
---
 open_instruct/finetune.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index d212d80bdf..800e2b4474 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -887,7 +887,20 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                         loss += aux_loss
                 # We keep track of the loss at each logged step
                 total_loss += loss.detach().float()
-                accelerator.backward(loss)
+
+                # [Pre-backward scaling]
+                # accelerator.backward internally divides by `gradient_accumulation_steps`, which is
+                # only the right thing to do for `mean` losses. For "sum" losses, we counteract this
+                # by multiplying the loss by `gradient_accumulation_steps` before the backward
+                # call. Additionally, DeepSpeed/FSDP average the gradients across processes, whereas
+                # we should be summing gradients for a "sum" loss, hence we also multiply by the
+                # world size in the latter case.
+                accelerator.backward(
+                    loss
+                    if args.reduce_loss == "mean"
+                    else loss * args.gradient_accumulation_steps * accelerator.num_processes
+                )
+
                 if args.load_balancing_loss:
                     total_aux_loss += aux_loss.detach().float()
                 # clip gradient norm. don't do this with deepspeed
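The pre-backward scaling can be checked with plain arithmetic: with accumulation depth G and world size P, the framework ultimately averages the P * G micro-batch gradients, so pre-multiplying each "sum" loss by G * P turns that average back into the desired sum. A standalone sketch with hypothetical numbers (not part of the patch):

G, P = 4, 8                                   # hypothetical accumulation steps and world size
micro_losses = [2.0, 3.0, 1.0, 4.0] * P       # one "sum" loss per micro-batch (illustrative)

framework_mean = sum(micro_losses) / (G * P)        # what the unscaled backward would aggregate to
scaled = [loss * G * P for loss in micro_losses]    # the patch's pre-backward scaling
recovered_sum = sum(scaled) / (G * P)               # the framework's averaging, applied to scaled losses

assert recovered_sum == sum(micro_losses)
print(f"mean-reduced: {framework_mean}, after scaling: {recovered_sum}")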
From ef0da352eec9159b1c268748224b623146ba6b19 Mon Sep 17 00:00:00 2001
From: Garrett Goon
Date: Wed, 25 Jun 2025 10:28:52 -0400
Subject: [PATCH 6/6] extra reporting

prev-branch: padding-free-squashing-7
---
 open_instruct/finetune.py | 92 ++++++++++++++++++++++++++++++++++-----
 1 file changed, 80 insertions(+), 12 deletions(-)

diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 800e2b4474..ac1e4a04e6 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -808,7 +808,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
         if "epoch" in training_difference:
             starting_epoch = int(training_difference.replace("epoch_", "")) + 1
-            resume_step = None
+            resume_step = 0
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
             # need to multiply `gradient_accumulation_steps` to reflect real steps
             resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
             completed_steps = resume_step // args.gradient_accumulation_steps
             resume_step -= starting_epoch * len(train_dataloader)
+    else:
+        resume_step = 0
+
     print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
     # update the progress_bar if load from checkpoint
     progress_bar.update(completed_steps)
     local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
+    local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
     total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
     start_time = time.time()
     skipped_batches = False
     for epoch in range(starting_epoch, args.num_train_epochs):
         train_dataloader.set_epoch(epoch)
         total_loss = 0
         total_aux_loss = 0
-        if last_checkpoint_path and resume_step is not None and not skipped_batches:
+        if last_checkpoint_path and resume_step and not skipped_batches:
             # We skip the first `n` batches in the dataloader when resuming from a checkpoint.
             active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
             # Only perform this skip once
             skipped_batches = True
         else:
             active_dataloader = train_dataloader
         for step, batch in enumerate(active_dataloader):
             if "attention_mask" in batch:
-                local_total_tokens += batch["attention_mask"].sum()
+                tokens_in_batch = batch["attention_mask"].sum()
                 total_token_including_padding += batch["attention_mask"].numel()
             elif "position_ids" in batch:
                 tokens_in_batch = batch["position_ids"].numel()
-                local_total_tokens += tokens_in_batch
                 total_token_including_padding += tokens_in_batch
             elif "cu_seq_lens_q" in batch:
                 tokens_in_batch = batch["cu_seq_lens_q"][-1]
-                local_total_tokens += tokens_in_batch
                 total_token_including_padding += tokens_in_batch
             else:
                 raise ValueError(
                     f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
                 )
+            local_total_tokens += tokens_in_batch
+            local_total_tokens_this_log_period += tokens_in_batch
+
             with accelerator.accumulate(model):
                 if args.load_balancing_loss:
                     outputs = model(**batch, use_cache=False, output_router_logits=True)
                     progress_bar.update(1)
                     completed_steps += 1
                     if args.logging_steps and completed_steps % args.logging_steps == 0:
-                        avg_loss = (
-                            accelerator.gather(total_loss).mean().item()
-                            / args.gradient_accumulation_steps
-                            / args.logging_steps
-                        )
+                        sum_loss = accelerator.gather(total_loss).sum().item()
                         total_tokens = accelerator.gather(local_total_tokens).sum().item()
-                        total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
+                        total_tokens_including_padding = (
+                            accelerator.gather(total_token_including_padding).sum().item()
+                        )
+                        total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
+                        local_total_tokens_this_log_period.zero_()
+
+                        avg_tokens_per_batch = (
+                            total_tokens
+                            / accelerator.num_processes
+                            / args.per_device_train_batch_size
+                            / args.gradient_accumulation_steps
+                            / completed_steps
+                        )
+                        avg_tokens_per_batch_including_padding = (
+                            total_tokens_including_padding
+                            / accelerator.num_processes
+                            / args.per_device_train_batch_size
+                            / args.gradient_accumulation_steps
+                            / completed_steps
+                        )
                         metrics_to_log = {
                             "learning_rate": lr_scheduler.get_last_lr()[0],
-                            "train_loss": avg_loss,
                             "total_tokens": total_tokens,
+                            "total_tokens_this_log_period": total_tokens_this_log_period,
+                            "avg_tokens_per_batch": avg_tokens_per_batch,
+                            "avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
-                            "per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
+                            "per_device_tps": total_tokens
+                            / accelerator.num_processes
+                            / (time.time() - start_time),
                             "total_tokens_including_padding": total_tokens_including_padding,
                             "per_device_tps_including_padding": total_tokens_including_padding
                             / accelerator.num_processes
                             / (time.time() - start_time),
+                            "reserved_mem_GiB": torch.cuda.max_memory_reserved(
+                                device=torch.cuda.current_device()
+                            )
+                            / 2**30,
+                            "allocated_mem_GiB": torch.cuda.max_memory_allocated(
+                                device=torch.cuda.current_device()
+                            )
+                            / 2**30,
                         }
+
+                        # [Loss Reporting]
+                        #
+                        # It is useful to handle loss-reporting for the "mean" and "sum" loss cases
+                        # differently. Cases:
+                        # 1) "mean" loss: `sum_loss` takes individual losses which were *averaged* over the toks in their
+                        #    sequence and sums them over all fwd passes in the logging period. We instead want the avg over
+                        #    these passes. Report avg_loss = sum_loss / total_fwd_passes, which is roughly independent of
+                        #    global batch size.
+                        # 2) "sum" loss: `sum_loss` takes individual losses which were *summed* over the toks in their
+                        #    sequence and sums them over all fwd passes in the logging period. We want both the avg over each
+                        #    optimizer step (which scales with the global batch size) and the average loss per token (which is
+                        #    roughly independent of global batch size). Report avg_sum_loss = sum_loss / total_optim_steps
+                        #    and avg_loss = sum_loss / total_tokens_this_log_period. The latter is roughly comparable to what
+                        #    we report in the "mean" loss case.
+                        if args.reduce_loss == "mean":
+                            total_fwd_passes = (
+                                args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
+                            )
+                            avg_loss = sum_loss / total_fwd_passes
+                        else:
+                            avg_loss = sum_loss / total_tokens_this_log_period
+                            total_optim_steps = args.logging_steps * accelerator.num_processes
+                            avg_sum_loss = sum_loss / total_optim_steps
+                            metrics_to_log["train_sum_loss"] = avg_sum_loss
+
+                        metrics_to_log["train_loss"] = avg_loss
+
+                        sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
+                        steps_remaining = args.max_train_steps - completed_steps
+                        secs_remaining = steps_remaining * sec_per_step
+                        accelerator.print(
+                            f"Approx. time remaining: {timedelta(seconds=secs_remaining)}. {args.max_train_steps=}, {completed_steps=}, {steps_remaining=}"
+                        )
+
                         if args.load_balancing_loss:
                             avg_aux_loss = (
                                 accelerator.gather(total_aux_loss).mean().item()
                                 / args.gradient_accumulation_steps
                                 / args.logging_steps
                             )
                         logger.info(
                             f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
                         )
+                        accelerator.print(f"{metrics_to_log=}")
                         if args.with_tracking:
                             accelerator.log(
                                 metrics_to_log,
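As a sanity check on the loss-reporting arithmetic above, a standalone sketch with hypothetical sizes, mirroring the divisors used in the patch:

logging_steps, grad_accum, world_size = 10, 4, 8             # hypothetical sizes
total_fwd_passes = logging_steps * grad_accum * world_size   # 320 forward passes per log period

# "mean" loss: every forward pass contributes a per-token-averaged loss (2.0 here),
# and avg_loss recovers that per-pass mean, independent of global batch size.
sum_loss = 2.0 * total_fwd_passes
avg_loss = sum_loss / total_fwd_passes
assert avg_loss == 2.0

# "sum" loss: every forward pass contributes a loss summed over its tokens.
tokens_per_pass, per_token_loss = 4096, 2.0
sum_loss = per_token_loss * tokens_per_pass * total_fwd_passes
total_tokens_this_log_period = tokens_per_pass * total_fwd_passes

avg_loss = sum_loss / total_tokens_this_log_period        # per-token, comparable to the "mean" case
avg_sum_loss = sum_loss / (logging_steps * world_size)    # per optimizer step, scales with batch size
assert avg_loss == per_token_loss
print(f"{avg_loss=}, {avg_sum_loss=}")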