131 changes: 112 additions & 19 deletions open_instruct/finetune.py
@@ -741,19 +741,27 @@ def main(args: FlatArguments, tc: TokenizerConfig):

if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_batch_idx = 0  # was: resume_step = None
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
# (was computed with a variable named `resume_step`; this PR renames it to `resume_batch_idx`)
resume_batch_idx = int(training_difference.replace("step_", "")) * args.gradient_accumulation_steps
starting_epoch = resume_batch_idx // len(train_dataloader)
completed_steps = resume_batch_idx // args.gradient_accumulation_steps
resume_batch_idx -= starting_epoch * len(train_dataloader)

print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
else:
resume_batch_idx = 0

resume_step = resume_batch_idx // args.gradient_accumulation_steps

print(f"Starting {starting_epoch=}, {resume_batch_idx=}, {resume_step=}, {completed_steps=}.")
# update the progress_bar if load from checkpoint
progress_bar.update(completed_steps)
local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_pred_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_pred_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
start_time = time.time()
skipped_batches = False
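As a sanity check on the resume arithmetic in the hunk above, here is a small standalone sketch (not part of this diff; all numbers are hypothetical):

# Resume arithmetic from the "step_" branch above, with made-up values.
gradient_accumulation_steps = 4
batches_per_epoch = 100  # stands in for len(train_dataloader)
training_difference = "step_60"  # checkpoint taken after 60 optimizer steps

# "step_N" counts optimizer steps, so convert to dataloader batches first.
resume_batch_idx = int(training_difference.replace("step_", "")) * gradient_accumulation_steps  # 240
starting_epoch = resume_batch_idx // batches_per_epoch  # 2
completed_steps = resume_batch_idx // gradient_accumulation_steps  # 60
resume_batch_idx -= starting_epoch * batches_per_epoch  # 40 batches left to skip in epoch 2
assert (starting_epoch, completed_steps, resume_batch_idx) == (2, 60, 40)
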
@@ -762,27 +770,33 @@ def main(args: FlatArguments, tc: TokenizerConfig):
train_dataloader.set_epoch(epoch)
total_loss = 0
total_aux_loss = 0
if last_checkpoint_path and resume_batch_idx and not skipped_batches:  # was: resume_step is not None
# We skip the first `n` batches in the dataloader when resuming from a checkpoint.
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_batch_idx)  # was: resume_step
# Only perform this skip once
skipped_batches = True
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
pred_tokens_in_batch = (batch["labels"] != -100).sum()
if "attention_mask" in batch:
local_total_tokens += batch["attention_mask"].sum()
tokens_in_batch = batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
elif "position_ids" in batch:
tokens_in_batch = batch["position_ids"].numel()
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
elif "cu_seq_lens_q" in batch:
tokens_in_batch = batch["cu_seq_lens_q"][-1]
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
else:
raise ValueError(f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}")
raise ValueError(
f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
)
local_total_tokens += tokens_in_batch
local_total_tokens_this_log_period += tokens_in_batch
local_pred_tokens += pred_tokens_in_batch
local_pred_tokens_this_log_period += pred_tokens_in_batch
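To make the three branches concrete, here is a toy illustration of the batch layouts they handle (illustrative only; the shapes, and everything except the three key names checked above, are made up):

import torch

# 1) Padded batch: attention_mask is 1 on real tokens, 0 on padding.
padded = {
    "attention_mask": torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
    "labels": torch.tensor([[-100, 5, 7, -100], [-100, 9, -100, -100]]),
}
assert padded["attention_mask"].sum().item() == 5  # real tokens
assert padded["attention_mask"].numel() == 8  # includes padding
assert (padded["labels"] != -100).sum().item() == 3  # prediction targets

# 2) Packed batch with position_ids: every position is a real token,
#    so numel() serves as both the real and the padded count.
packed = {"position_ids": torch.arange(8).unsqueeze(0)}
assert packed["position_ids"].numel() == 8

# 3) Flash-attention-style packing: cu_seq_lens_q holds cumulative sequence
#    lengths, so its last entry is the total token count.
flash = {"cu_seq_lens_q": torch.tensor([0, 3, 8])}  # two sequences of 3 and 5 tokens
assert flash["cu_seq_lens_q"][-1].item() == 8
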

with accelerator.accumulate(model):
if args.load_balancing_loss:
outputs = model(**batch, use_cache=False, output_router_logits=True)
@@ -839,23 +853,101 @@ def main(args: FlatArguments, tc: TokenizerConfig):
progress_bar.update(1)
completed_steps += 1
if args.logging_steps and completed_steps % args.logging_steps == 0:
# was: avg_loss = (accelerator.gather(total_loss).mean().item() ...)
sum_loss = accelerator.gather(total_loss).sum().item()
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_pred_tokens = accelerator.gather(local_pred_tokens).sum().item()
total_tokens_including_padding = (
accelerator.gather(total_token_including_padding).sum().item()
)
total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
local_total_tokens_this_log_period.zero_()
pred_tokens_this_log_period = accelerator.gather(local_pred_tokens_this_log_period).sum().item()
local_pred_tokens_this_log_period.zero_()

avg_tokens_per_batch = (
    total_tokens
    / accelerator.num_processes
    / args.per_device_train_batch_size
    / args.gradient_accumulation_steps
    / completed_steps  # was: / args.logging_steps
)
avg_tokens_per_batch_including_padding = (
total_tokens_including_padding
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ completed_steps
)
avg_pred_tokens_per_batch = (
total_pred_tokens
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ completed_steps
)
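Note that the divisor in these averages counts sequences: num_processes * per_device_train_batch_size * gradient_accumulation_steps * completed_steps is the number of sequence slots seen so far, so the metric is an average sequence length over the run. A quick check with hypothetical numbers:

# Hypothetical values to sanity-check the units of avg_tokens_per_batch.
num_processes = 8
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
completed_steps = 100
total_tokens = 6_400_000  # cumulative non-padding tokens gathered across ranks

sequence_slots = num_processes * per_device_train_batch_size * gradient_accumulation_steps * completed_steps
print(total_tokens / sequence_slots)  # 1000.0 tokens per sequence slot, averaged over the run
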
# (the old single-line "train_loss" / "per_device_tps" entries and the duplicate gathers
# above them were replaced by the expanded accounting in this hunk; "train_loss" is now
# set below, under [Loss Reporting])
metrics_to_log = {
    "learning_rate": lr_scheduler.get_last_lr()[0],
    "total_tokens": total_tokens,
"total_tokens_including_padding": total_tokens_including_padding,
"total_pred_tokens": total_pred_tokens,
"total_tokens_this_log_period": total_tokens_this_log_period,
"avg_tokens_per_batch": avg_tokens_per_batch,
"avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
"avg_pred_tokens_per_batch": avg_pred_tokens_per_batch,
"per_device_tps": total_tokens
/ accelerator.num_processes
/ (time.time() - start_time),
"per_device_tps_including_padding": total_tokens_including_padding
/ accelerator.num_processes
/ (time.time() - start_time),
"reserved_mem_GiB": torch.cuda.max_memory_reserved(
device=torch.cuda.current_device()
)
/ 2**30,
"allocated_mem_GiB": torch.cuda.max_memory_allocated(
device=torch.cuda.current_device()
)
/ 2**30,
}

# [Loss Reporting]
#
# It is useful to handle loss-reporting for the "mean" and "sum" loss cases
# differently. Cases:
# 1) "mean" loss: `sum_loss` takes individual losses which were *averaged* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We instead want the avg over
# these passes. Report avg_loss = sum_loss / total_fwd_passes, which is roughly independent of
# global batch size.
# 2) "sum" loss: `sum_loss` takes individual losses which were *summed* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We want both the avg over each
# optimizer step (which scales with the global batch size) and the average loss per token (which is
# roughly independent of global batch size). Report avg_sum_loss = sum_loss / total_optim_steps
# and avg_loss = sum_loss / total_tokens_this_log_period. The latter is roughly comparable to what
# we report in the "mean" loss case.
if args.reduce_loss == "mean":
total_fwd_passes = (
args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
)
avg_loss = sum_loss / total_fwd_passes
metrics_to_log["train_loss"] = avg_loss
else:
avg_loss_per_total_tok = sum_loss / total_tokens_this_log_period
# The loss per pred tok is the closest analogue to what we report as the
# avg_loss in the "mean" case
avg_loss = sum_loss / pred_tokens_this_log_period
total_optim_steps = args.logging_steps * accelerator.num_processes
avg_sum_loss = sum_loss / total_optim_steps
metrics_to_log["train_sum_loss"] = avg_sum_loss
metrics_to_log["train_loss_per_total_tok"] = avg_loss_per_total_tok
metrics_to_log["train_loss_per_pred_tok"] = avg_loss

sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
steps_remaining = args.max_train_steps - completed_steps
secs_remaining = steps_remaining * sec_per_step
accelerator.print(
f"Approx. time remaining: {timedelta(seconds=secs_remaining)}. {args.max_train_steps=}, {completed_steps=}, {steps_remaining=}"
)
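Note the denominator completed_steps - resume_step: steps skipped on resume are excluded from the timing, so the ETA reflects only steps actually trained this run. For example (hypothetical numbers):

from datetime import timedelta
import time

start_time = time.time() - 500  # pretend 500 s have elapsed this run
completed_steps, resume_step, max_train_steps = 160, 60, 1000
sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)  # ~5.0 s/step
secs_remaining = (max_train_steps - completed_steps) * sec_per_step  # ~4200 s
print(timedelta(seconds=secs_remaining))  # ~1:10:00
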

if args.load_balancing_loss:
avg_aux_loss = (
accelerator.gather(total_aux_loss).mean().item()
@@ -870,6 +962,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
)
accelerator.print(f"{metrics_to_log=}")
if args.with_tracking:
accelerator.log(metrics_to_log, step=completed_steps)
total_loss = 0