diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py
index 141c08d09a..b9d62fc892 100644
--- a/open_instruct/finetune.py
+++ b/open_instruct/finetune.py
@@ -741,19 +741,27 @@ def main(args: FlatArguments, tc: TokenizerConfig):
         if "epoch" in training_difference:
             starting_epoch = int(training_difference.replace("epoch_", "")) + 1
-            resume_step = None
+            resume_batch_idx = 0
             completed_steps = starting_epoch * num_update_steps_per_epoch
         else:
             # need to multiply `gradient_accumulation_steps` to reflect real steps
-            resume_step = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
-            starting_epoch = resume_step // len(train_dataloader)
-            completed_steps = resume_step // args.gradient_accumulation_steps
-            resume_step -= starting_epoch * len(train_dataloader)
+            resume_batch_idx = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
+            starting_epoch = resume_batch_idx // len(train_dataloader)
+            completed_steps = resume_batch_idx // args.gradient_accumulation_steps
+            resume_batch_idx -= starting_epoch * len(train_dataloader)
 
-        print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
+    else:
+        resume_batch_idx = 0
+
+    resume_step = resume_batch_idx // args.gradient_accumulation_steps
+
+    print(f"Starting {starting_epoch=}, {resume_batch_idx=}, {resume_step=}, {completed_steps=}.")
 
     # update the progress_bar if load from checkpoint
     progress_bar.update(completed_steps)
     local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
+    local_pred_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
+    local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
+    local_pred_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
     total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
     start_time = time.time()
     skipped_batches = False
@@ -762,27 +770,33 @@ def main(args: FlatArguments, tc: TokenizerConfig):
         train_dataloader.set_epoch(epoch)
         total_loss = 0
         total_aux_loss = 0
-        if last_checkpoint_path and resume_step is not None and not skipped_batches:
+        if last_checkpoint_path and resume_batch_idx and not skipped_batches:
             # We skip the first `n` batches in the dataloader when resuming from a checkpoint.
-            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
+            active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_batch_idx)
             # Only perform this skip once
             skipped_batches = True
         else:
             active_dataloader = train_dataloader
         for step, batch in enumerate(active_dataloader):
+            pred_tokens_in_batch = (batch["labels"] != -100).sum()
             if "attention_mask" in batch:
-                local_total_tokens += batch["attention_mask"].sum()
+                tokens_in_batch = batch["attention_mask"].sum()
                 total_token_including_padding += batch["attention_mask"].numel()
             elif "position_ids" in batch:
                 tokens_in_batch = batch["position_ids"].numel()
-                local_total_tokens += tokens_in_batch
                 total_token_including_padding += tokens_in_batch
             elif "cu_seq_lens_q" in batch:
                 tokens_in_batch = batch["cu_seq_lens_q"][-1]
-                local_total_tokens += tokens_in_batch
                 total_token_including_padding += tokens_in_batch
             else:
-                raise ValueError(f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}")
+                raise ValueError(
+                    f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
+                )
+            local_total_tokens += tokens_in_batch
+            local_total_tokens_this_log_period += tokens_in_batch
+            local_pred_tokens += pred_tokens_in_batch
+            local_pred_tokens_this_log_period += pred_tokens_in_batch
+
             with accelerator.accumulate(model):
                 if args.load_balancing_loss:
                     outputs = model(**batch, use_cache=False, output_router_logits=True)
@@ -839,23 +853,101 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                     progress_bar.update(1)
                     completed_steps += 1
                     if args.logging_steps and completed_steps % args.logging_steps == 0:
-                        avg_loss = (
-                            accelerator.gather(total_loss).mean().item()
+                        sum_loss = accelerator.gather(total_loss).sum().item()
+                        total_tokens = accelerator.gather(local_total_tokens).sum().item()
+                        total_pred_tokens = accelerator.gather(local_pred_tokens).sum().item()
+                        total_tokens_including_padding = (
+                            accelerator.gather(total_token_including_padding).sum().item()
+                        )
+                        total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
+                        local_total_tokens_this_log_period.zero_()
+                        pred_tokens_this_log_period = accelerator.gather(local_pred_tokens_this_log_period).sum().item()
+                        local_pred_tokens_this_log_period.zero_()
+
+                        avg_tokens_per_batch = (
+                            total_tokens
+                            / accelerator.num_processes
+                            / args.per_device_train_batch_size
                             / args.gradient_accumulation_steps
-                            / args.logging_steps
+                            / completed_steps
+                        )
+                        avg_tokens_per_batch_including_padding = (
+                            total_tokens_including_padding
+                            / accelerator.num_processes
+                            / args.per_device_train_batch_size
+                            / args.gradient_accumulation_steps
+                            / completed_steps
+                        )
+                        avg_pred_tokens_per_batch = (
+                            total_pred_tokens
+                            / accelerator.num_processes
+                            / args.per_device_train_batch_size
+                            / args.gradient_accumulation_steps
+                            / completed_steps
                         )
-                        total_tokens = accelerator.gather(local_total_tokens).sum().item()
-                        total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
                         metrics_to_log = {
                             "learning_rate": lr_scheduler.get_last_lr()[0],
-                            "train_loss": avg_loss,
                             "total_tokens": total_tokens,
-                            "per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
                             "total_tokens_including_padding": total_tokens_including_padding,
+                            "total_pred_tokens": total_pred_tokens,
+                            "total_tokens_this_log_period": total_tokens_this_log_period,
+                            "avg_tokens_per_batch": avg_tokens_per_batch,
+                            "avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
+                            "avg_pred_tokens_per_batch": avg_pred_tokens_per_batch,
+                            "per_device_tps": total_tokens
+                            / accelerator.num_processes
+                            / (time.time() - start_time),
                             "per_device_tps_including_padding": total_tokens_including_padding
                             / accelerator.num_processes
                             / (time.time() - start_time),
+                            "reserved_mem_GiB": torch.cuda.max_memory_reserved(
+                                device=torch.cuda.current_device()
+                            )
+                            / 2**30,
+                            "allocated_mem_GiB": torch.cuda.max_memory_allocated(
+                                device=torch.cuda.current_device()
+                            )
+                            / 2**30,
                         }
+
+                        # [Loss Reporting]
+                        #
+                        # It is useful to handle loss reporting for the "mean" and "sum" loss cases
+                        # differently. Cases:
+                        # 1) "mean" loss: `sum_loss` takes individual losses which were *averaged* over the toks in their
+                        #    sequence and sums them over all fwd passes in the logging period. We instead want the avg over
+                        #    these passes. Report avg_loss = sum_loss / total_fwd_passes, which is roughly independent of
+                        #    global batch size.
+                        # 2) "sum" loss: `sum_loss` takes individual losses which were *summed* over the toks in their
+                        #    sequence and sums them over all fwd passes in the logging period. We want both the avg over each
+                        #    optimizer step (which scales with the global batch size) and the average loss per token (which is
+                        #    roughly independent of global batch size). Report avg_sum_loss = sum_loss / total_optim_steps
+                        #    along with the loss per total token and per pred token; the latter is roughly comparable to
+                        #    what we report in the "mean" loss case.
+                        if args.reduce_loss == "mean":
+                            total_fwd_passes = (
+                                args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
+                            )
+                            avg_loss = sum_loss / total_fwd_passes
+                            metrics_to_log["train_loss"] = avg_loss
+                        else:
+                            avg_loss_per_total_tok = sum_loss / total_tokens_this_log_period
+                            # The loss per pred tok is the closest analogue to what we report as the
+                            # avg_loss in the "mean" case
+                            avg_loss = sum_loss / pred_tokens_this_log_period
+                            total_optim_steps = args.logging_steps * accelerator.num_processes
+                            avg_sum_loss = sum_loss / total_optim_steps
+                            metrics_to_log["train_sum_loss"] = avg_sum_loss
+                            metrics_to_log["train_loss_per_total_tok"] = avg_loss_per_total_tok
+                            metrics_to_log["train_loss_per_pred_tok"] = avg_loss
+
+                        sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
+                        steps_remaining = args.max_train_steps - completed_steps
+                        secs_remaining = steps_remaining * sec_per_step
+                        accelerator.print(
+                            f"Approx. time remaining: {timedelta(seconds=secs_remaining)}. {args.max_train_steps=}, {completed_steps=}, {steps_remaining=}"
+                        )
+
                         if args.load_balancing_loss:
                             avg_aux_loss = (
                                 accelerator.gather(total_aux_loss).mean().item()
@@ -870,6 +962,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
                             logger.info(
                                 f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
                             )
+                        accelerator.print(f"{metrics_to_log=}")
                         if args.with_tracking:
                             accelerator.log(metrics_to_log, step=completed_steps)
                         total_loss = 0
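To make the resume bookkeeping in the first hunk concrete, here is a small standalone sketch (not part of the patch; the checkpoint name and dataloader size are made up) of how a `step_{i}` checkpoint is converted into `resume_batch_idx`, `starting_epoch`, `completed_steps`, and `resume_step`:

    # Hypothetical numbers: resuming from "step_500" with gradient_accumulation_steps=4
    # and a train_dataloader of 1500 batches per epoch.
    training_difference = "step_500"
    gradient_accumulation_steps = 4
    batches_per_epoch = 1500  # stands in for len(train_dataloader)

    # optimizer steps recorded in the checkpoint name -> dataloader batches already consumed
    resume_batch_idx = int(training_difference.replace("step_", "")) * gradient_accumulation_steps  # 2000
    starting_epoch = resume_batch_idx // batches_per_epoch             # 1
    completed_steps = resume_batch_idx // gradient_accumulation_steps  # 500
    resume_batch_idx -= starting_epoch * batches_per_epoch             # 500 batches into epoch 1
    resume_step = resume_batch_idx // gradient_accumulation_steps      # 125 optimizer steps into epoch 1

    print(f"{starting_epoch=}, {resume_batch_idx=}, {resume_step=}, {completed_steps=}")
    # starting_epoch=1, resume_batch_idx=500, resume_step=125, completed_steps=500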
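The [Loss Reporting] comment is the core of the logging change; the following standalone sketch (the helper name and example numbers are assumptions, not part of the patch) shows the same normalization applied to `sum_loss` for the two `reduce_loss` settings:

    # Hypothetical helper mirroring the normalization described in the [Loss Reporting] block.
    def loss_metrics(reduce_loss, sum_loss, logging_steps, grad_accum_steps, num_processes,
                     total_tokens_this_log_period, pred_tokens_this_log_period):
        if reduce_loss == "mean":
            # Each fwd pass contributed a per-sequence *mean* loss, so average over all
            # fwd passes in the logging period.
            total_fwd_passes = logging_steps * grad_accum_steps * num_processes
            return {"train_loss": sum_loss / total_fwd_passes}
        # "sum" reduction: each fwd pass contributed a loss *summed* over its tokens.
        total_optim_steps = logging_steps * num_processes
        return {
            "train_sum_loss": sum_loss / total_optim_steps,  # scales with global batch size
            "train_loss_per_total_tok": sum_loss / total_tokens_this_log_period,
            "train_loss_per_pred_tok": sum_loss / pred_tokens_this_log_period,  # ~ "mean"-case train_loss
        }

    # Example with made-up numbers: 2 processes, grad accumulation 4, logging every 10 steps.
    print(loss_metrics("sum", sum_loss=81920.0, logging_steps=10, grad_accum_steps=4,
                       num_processes=2, total_tokens_this_log_period=160_000,
                       pred_tokens_this_log_period=40_960))
    # {'train_sum_loss': 4096.0, 'train_loss_per_total_tok': 0.512, 'train_loss_per_pred_tok': 2.0}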