200 changes: 182 additions & 18 deletions open_instruct/finetune.py
@@ -40,9 +40,11 @@
import torch
import transformers
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.accelerator import GradientAccumulationPlugin
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, set_seed
from huggingface_hub import HfApi
from padding_free_collator import TensorDataCollatorWithFlattening
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from rich.pretty import pprint
from torch.utils.data import DataLoader
@@ -362,6 +364,28 @@ class FlatArguments:
oe_eval_max_length: int = 4096
"""the max generation length for evaluation for oe-eval"""

padding_free: bool = field(
default=False,
metadata={"help": "Whether to use padding-free collation via TensorDataCollatorWithFlattening"},
)
clean_checkpoints_at_end: bool = field(
default=False,
metadata={
"help": "Whether to clean up all previous checkpoints at the end of the run.",
},
)
final_lr_ratio: float = field(
default=0.1,
metadata={
"help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
},
)
add_seed_and_date_to_run_name: bool = True
additional_model_arguments: Optional[list[str]] = field(
default=None,
metadata={"help": "A list of key:val to be passed as additional model args."},
)

def __post_init__(self):
if self.reduce_loss not in ["mean", "sum"]:
raise ValueError("reduce_loss must be either 'mean' or 'sum'")
@@ -376,6 +400,24 @@ def __post_init__(self):
if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")

if self.additional_model_arguments is not None:
import re

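# Best-effort value parsing for key:val pairs: a string with a single "." and a numeric
# remainder becomes a float, an all-digit string becomes an int, anything else is kept as a
# string. E.g. the (hypothetical) input ["rope_theta:500000", "rope_scaling_factor:2.0"]
# parses to {"rope_theta": 500000, "rope_scaling_factor": 2.0}.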
def maybe_convert_(x):
return (
float(x)
if x.count(".") == 1 and re.sub(r"^-?.*\.", "", x, count=1).isnumeric()
else (int(x) if x.count(".") == 0 and re.sub("^-?", "", x).isnumeric() else x)
)

try:
self.additional_model_arguments = [x.split(":") for x in self.additional_model_arguments]
self.additional_model_arguments = {k: maybe_convert_(v) for k, v in self.additional_model_arguments}
except (IndexError, ValueError):
raise ValueError("Malformed additional model arguments. Should be space-delimited list of key:val.")
else:
self.additional_model_arguments = {}


def main(args: FlatArguments, tc: TokenizerConfig):
# ------------------------------------------------------------
@@ -389,11 +431,15 @@ def main(args: FlatArguments, tc: TokenizerConfig):
# if you get timeouts (e.g. due to long tokenization) increase this.
timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
dataloader_config=dataloader_config,
**accelerator_log_kwargs,
kwargs_handlers=[timeout_kwargs],
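# sync_each_batch=True makes accelerate synchronize gradients on every backward pass rather
# than only at accumulation boundaries; per the accelerate docs this trades some extra
# communication for lower peak memory when accumulating gradients (e.g. under FSDP).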
gradient_accumulation_plugin=GradientAccumulationPlugin(
num_steps=args.gradient_accumulation_steps,
sync_each_batch=True,
),
)

# ------------------------------------------------------------
Expand All @@ -413,7 +459,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):

# ------------------------------------------------------------
# Set up runtime variables
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"

# TODO: @goon - remove run_name from FlatArguments; it's just ignored and overwritten with
# exp_name
if args.add_seed_and_date_to_run_name:
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
else:
args.run_name = args.exp_name
if not args.do_not_randomize_output_dir:
args.output_dir = os.path.join(args.output_dir, args.run_name)
logger.info("using the output directory: %s", args.output_dir)
@@ -490,7 +542,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):

if args.dataset_mixer is not None:
args.dataset_mixer_list = [item for pair in args.dataset_mixer.items() for item in pair]
with accelerator.main_process_first():
with accelerator.local_main_process_first():
transform_fn_args = [
{"max_seq_length": args.max_seq_length},
{},
@@ -522,12 +574,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
args.config_name,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
else:
raise ValueError(
@@ -630,10 +684,18 @@ def main(args: FlatArguments, tc: TokenizerConfig):
model.gradient_checkpointing_enable()

# DataLoaders creation:
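# With padding_free=True, the flattening collator packs every sequence in the batch into a
# single row and (as the token-counting branches in the training loop below assume) emits
# position_ids / cu_seq_lens_* in place of an attention_mask, so no padding tokens are
# materialized.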
if args.padding_free:
collate_fn = TensorDataCollatorWithFlattening()
else:
collate_fn = DataCollatorForSeq2Seq(
tokenizer=tokenizer, model=model, padding="longest"
)

accelerator.print("Creating dataloader")
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
collate_fn=collate_fn,
batch_size=args.per_device_train_batch_size,
)

@@ -688,11 +750,24 @@ def main(args: FlatArguments, tc: TokenizerConfig):
num_training_steps_for_scheduler = (
args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
)

# HACK: @goon - final_lr_ratio assumes a linear scheduler, so putting this assert in to sanity
# check.
assert args.lr_scheduler_type == "linear"
assert 1.0 > args.final_lr_ratio >= 0.0

# HACK: @goon - adjust num_training_steps_for_scheduler so that the final LR is learning_rate *
# final_lr_ratio.
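# Derivation (for the linear schedule asserted above): after W warmup steps the LR decays
# linearly from learning_rate at step W to 0 at step T, where T is the value handed to the
# scheduler. At the true final step T0, LR(T0) = learning_rate * (T - T0) / (T - W).
# Requiring LR(T0) = final_lr_ratio * learning_rate and solving for T gives
# T = (T0 - final_lr_ratio * W) / (1 - final_lr_ratio), which is the adjustment below.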
num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
num_training_steps_for_scheduler = (
num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
) / (1 - args.final_lr_ratio)

lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
num_warmup_steps=num_warmup_steps,
)
# Prepare everything with `accelerator`.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -736,7 +811,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):

if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
resume_step = 0
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
@@ -745,10 +820,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
completed_steps = resume_step // args.gradient_accumulation_steps
resume_step -= starting_epoch * len(train_dataloader)

else:
resume_step = 0

print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
# update the progress_bar if resuming from a checkpoint
progress_bar.update(completed_steps)
local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
start_time = time.time()
skipped_batches = False
@@ -757,16 +836,30 @@ def main(args: FlatArguments, tc: TokenizerConfig):
train_dataloader.set_epoch(epoch)
total_loss = 0
total_aux_loss = 0
if last_checkpoint_path and resume_step is not None and not skipped_batches:
if last_checkpoint_path and resume_step and not skipped_batches:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint.
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
# Only perform this skip once
skipped_batches = True
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
local_total_tokens += batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
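# Count the tokens in this batch. Padded batches expose an attention_mask; padding-free
# batches from the flattening collator instead expose position_ids or the flash-attention
# style cu_seq_lens_q offsets, whose final entry is the total number of tokens in the
# packed batch.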
if "attention_mask" in batch:
tokens_in_batch = batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
elif "position_ids" in batch:
tokens_in_batch = batch["position_ids"].numel()
total_token_including_padding += tokens_in_batch
elif "cu_seq_lens_q" in batch:
tokens_in_batch = batch["cu_seq_lens_q"][-1]
total_token_including_padding += tokens_in_batch
else:
raise ValueError(
f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
)
local_total_tokens += tokens_in_batch
local_total_tokens_this_log_period += tokens_in_batch

with accelerator.accumulate(model):
if args.load_balancing_loss:
outputs = model(**batch, use_cache=False, output_router_logits=True)
@@ -802,7 +895,15 @@ def main(args: FlatArguments, tc: TokenizerConfig):
loss += aux_loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)
# accelerator.backward internally divides by `gradient_accumulation_steps`, which is
# only the right thing to do for `mean` losses, so we counteract this operation for
# "sum" losses. Strictly speaking, we should also multiply by the world size to
# turn DeepSpeed/FSDP grad-averaging into grad-summing (TODO).
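# E.g. with gradient_accumulation_steps=4 and reduce_loss="sum", accelerate's internal
# division by 4 would shrink the summed-loss gradients; the multiplication below undoes it.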
accelerator.backward(
loss
if args.reduce_loss == "mean"
else loss * args.gradient_accumulation_steps
)
if args.load_balancing_loss:
total_aux_loss += aux_loss.detach().float()
# clip gradient norm. don't do this with deepspeed
@@ -817,23 +918,85 @@ def main(args: FlatArguments, tc: TokenizerConfig):
progress_bar.update(1)
completed_steps += 1
if args.logging_steps and completed_steps % args.logging_steps == 0:
avg_loss = (
accelerator.gather(total_loss).mean().item()
sum_loss = accelerator.gather(total_loss).sum().item()
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_tokens_including_padding = (
accelerator.gather(total_token_including_padding).sum().item()
)
total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
local_total_tokens_this_log_period.zero_()

avg_tokens_per_batch = (
total_tokens
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ args.logging_steps
/ completed_steps
)
avg_tokens_per_batch_including_padding = (
total_tokens_including_padding
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ completed_steps
)
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
metrics_to_log = {
"learning_rate": lr_scheduler.get_last_lr()[0],
"train_loss": avg_loss,
"total_tokens": total_tokens,
"per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
"total_tokens_this_log_period": total_tokens_this_log_period,
"avg_tokens_per_batch": avg_tokens_per_batch,
"avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
"per_device_tps": total_tokens
/ accelerator.num_processes
/ (time.time() - start_time),
"total_tokens_including_padding": total_tokens_including_padding,
"per_device_tps_including_padding": total_tokens_including_padding
/ accelerator.num_processes
/ (time.time() - start_time),
"reserved_mem_GiB": torch.cuda.max_memory_reserved(
device=torch.cuda.current_device()
)
/ 2**30,
"allocated_mem_GiB": torch.cuda.max_memory_allocated(
device=torch.cuda.current_device()
)
/ 2**30,
}

# [Loss Reporting]
#
# It is useful to handle loss-reporting for the "mean" and "sum" loss cases
# differently. Cases:
# 1) "mean" loss: `sum_loss` takes individual losses which were *averaged* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We instead want the avg over
# these passes. Report avg_loss = sum_loss / total_fwd_passes, which is roughly independent of
# global batch size.
# 2) "sum" loss: `sum_loss` takes individual losses which were *summed* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We want both the avg over each
# optimizer step (which scales with the global batch size) and the average loss per token (which is
# roughly independent of global batch size). Report avg_sum_loss = sum_loss / total_optim_steps
# and avg_loss = sum_loss / total_tokens_this_log_period. The latter is roughly comparable to what
# we report in the "mean" loss case.
if args.reduce_loss == "mean":
total_fwd_passes = (
args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
)
avg_loss = sum_loss / total_fwd_passes
else:
avg_loss = sum_loss / total_tokens_this_log_period
total_optim_steps = args.logging_steps * accelerator.num_processes
avg_sum_loss = sum_loss / total_optim_steps
metrics_to_log["train_sum_loss"] = avg_sum_loss

metrics_to_log["train_loss"] = avg_loss

sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
steps_remaining = args.max_train_steps - completed_steps
secs_remaining = steps_remaining * sec_per_step
accelerator.print(
f"Approx. time remaining: {timedelta(seconds=secs_remaining)}. {args.max_train_steps=}, {completed_steps=}, {steps_remaining=}"
)

if args.load_balancing_loss:
avg_aux_loss = (
accelerator.gather(total_aux_loss).mean().item()
@@ -848,6 +1011,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
)
accelerator.print(f"{metrics_to_log=}")
if args.with_tracking:
accelerator.log(
metrics_to_log,
@@ -905,7 +1069,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
)

# remove all checkpoints to save space
if accelerator.is_local_main_process:
if args.clean_checkpoints_at_end and accelerator.is_local_main_process:
clean_last_n_checkpoints(args.output_dir, keep_last_n_checkpoints=0)

if (