175 changes: 154 additions & 21 deletions open_instruct/finetune.py
@@ -40,6 +40,7 @@
import torch
import transformers
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.accelerator import GradientAccumulationPlugin
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, set_seed
from huggingface_hub import HfApi
@@ -87,8 +88,6 @@ class FlatArguments:

exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of this experiment"""
run_name: Optional[str] = None
"""A unique name of this run"""
do_not_randomize_output_dir: bool = False
"""By default the output directory will be randomized"""
model_name_or_path: Optional[str] = field(
@@ -373,6 +372,18 @@ class FlatArguments:
"help": "Whether to clean up all previous checkpoints at the end of the run.",
},
)
final_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Set the final lr value at the end of training to be final_lr_ratio * learning_rate."
" Only for linear schedulers, currently."
},
)
add_seed_and_date_to_exp_name: bool = True
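"""Whether to append the seed and a Unix timestamp to the experiment name"""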
additional_model_arguments: Optional[list[str]] = field(
default=None,
metadata={"help": "A list of key:val to be passed as additional model args."},
)

def __post_init__(self):
if self.reduce_loss not in ["mean", "sum"]:
@@ -387,6 +398,29 @@ def __post_init__(self):
raise ValueError("Cannot provide two dataset selection mechanisms.")
if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
if self.final_lr_ratio is not None:
if self.lr_scheduler_type != "linear":
raise NotImplementedError("final_lr_ratio only currently implemented for linear schedulers")
if not (1.0 >= self.final_lr_ratio >= 0.0):
raise ValueError(f"final_lr_ratio must be between 0 and 1, not {self.final_lr_ratio=}")

if self.additional_model_arguments is not None:
import re

def maybe_convert_(x):
return (
float(x)
if x.count(".") == 1 and re.sub(r"^-?.*\.", "", x, count=1).isnumeric()
else (int(x) if x.count(".") == 0 and re.sub("^-?", "", x).isnumeric() else x)
)

try:
self.additional_model_arguments = [x.split(":") for x in self.additional_model_arguments]
self.additional_model_arguments = {k: maybe_convert_(v) for k, v in self.additional_model_arguments}
except (IndexError, ValueError):
raise ValueError("Malformed additional model arguments. Should be space-delimited list of key:val.")
else:
self.additional_model_arguments = {}
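
# Illustrative sketch (standalone, with example keys not taken from this change) of how
# `additional_model_arguments` is parsed by `__post_init__` above: each "key:val" string is
# split on ":" and the value is converted to int or float when it looks numeric, otherwise
# left as a string.
import re

def maybe_convert_(x):
    return (
        float(x)
        if x.count(".") == 1 and re.sub(r"^-?.*\.", "", x, count=1).isnumeric()
        else (int(x) if x.count(".") == 0 and re.sub("^-?", "", x).isnumeric() else x)
    )

raw = ["rope_theta:500000", "attention_dropout:0.1", "tie_word_embeddings:false"]
parsed = {k: maybe_convert_(v) for k, v in (x.split(":") for x in raw)}
assert parsed == {"rope_theta": 500000, "attention_dropout": 0.1, "tie_word_embeddings": "false"}
# The resulting dict is forwarded verbatim as **kwargs to AutoConfig.from_pretrained below.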


def main(args: FlatArguments, tc: TokenizerConfig):
@@ -401,11 +435,15 @@ def main(args: FlatArguments, tc: TokenizerConfig):
# if you get timeouts (e.g. due to long tokenization) increase this.
timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
dataloader_config=dataloader_config,
**accelerator_log_kwargs,
kwargs_handlers=[timeout_kwargs],
gradient_accumulation_plugin=GradientAccumulationPlugin(
num_steps=args.gradient_accumulation_steps,
sync_each_batch=True,
),
)

# ------------------------------------------------------------
@@ -425,9 +463,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):

# ------------------------------------------------------------
# Set up runtime variables
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"

if args.add_seed_and_date_to_exp_name:
args.exp_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
else:
args.exp_name = args.exp_name
if not args.do_not_randomize_output_dir:
args.output_dir = os.path.join(args.output_dir, args.run_name)
args.output_dir = os.path.join(args.output_dir, args.exp_name)
logger.info("using the output directory: %s", args.output_dir)
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
if is_beaker_job():
@@ -441,7 +483,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
args.hf_entity = HfApi().whoami()["name"]
args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
if args.hf_repo_revision is None:
args.hf_repo_revision = args.run_name
args.hf_repo_revision = args.exp_name
args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
if is_beaker_job():
beaker_config = maybe_get_beaker_config()
Expand All @@ -465,7 +507,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
experiment_config,
init_kwargs={
"wandb": {
"name": args.run_name,
"name": args.exp_name,
"entity": args.wandb_entity,
"tags": [args.exp_name] + get_wandb_tags(),
}
@@ -534,12 +576,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
args.config_name,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
else:
raise ValueError(
@@ -708,11 +752,19 @@ def main(args: FlatArguments, tc: TokenizerConfig):
num_training_steps_for_scheduler = (
args.max_train_steps if overrode_max_train_steps else args.max_train_steps * accelerator.num_processes
)

num_warmup_steps = int(num_training_steps_for_scheduler * args.warmup_ratio)
if args.final_lr_ratio is not None and args.lr_scheduler_type == "linear":
# Correct num_training_steps_for_scheduler to respect final_lr_ratio for a linear scheduler
num_training_steps_for_scheduler = (
num_training_steps_for_scheduler - args.final_lr_ratio * num_warmup_steps
) / (1 - args.final_lr_ratio)

lr_scheduler = get_scheduler(
name=args.lr_scheduler_type,
optimizer=optimizer,
num_training_steps=num_training_steps_for_scheduler,
num_warmup_steps=int(num_training_steps_for_scheduler * args.warmup_ratio),
num_warmup_steps=num_warmup_steps,
)
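
# Sanity check (standalone, made-up numbers) for the `final_lr_ratio` correction above.
# A linear schedule with warmup decays as lr(t) = lr_max * (T - t) / (T - W) for t >= W,
# where T is the step count handed to the scheduler and W the warmup steps. Requiring
# lr(N) = final_lr_ratio * lr_max at the real number of training steps N and solving for T
# gives T = (N - final_lr_ratio * W) / (1 - final_lr_ratio), which is what the code computes.
N, warmup_ratio, final_lr_ratio = 10_000, 0.03, 0.1    # hypothetical values
W = int(N * warmup_ratio)                              # 300 warmup steps
T = (N - final_lr_ratio * W) / (1 - final_lr_ratio)    # stretched scheduler horizon
assert abs((T - N) / (T - W) - final_lr_ratio) < 1e-9  # LR at step N ends at 10% of lr_max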
# Prepare everything with `accelerator`.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -756,7 +808,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):

if "epoch" in training_difference:
starting_epoch = int(training_difference.replace("epoch_", "")) + 1
resume_step = None
resume_step = 0
completed_steps = starting_epoch * num_update_steps_per_epoch
else:
# need to multiply `gradient_accumulation_steps` to reflect real steps
@@ -765,10 +817,14 @@ def main(args: FlatArguments, tc: TokenizerConfig):
completed_steps = resume_step // args.gradient_accumulation_steps
resume_step -= starting_epoch * len(train_dataloader)

else:
resume_step = 0

print(f"Starting from epoch {starting_epoch} and step {completed_steps}.")
# update the progress_bar if load from checkpoint
progress_bar.update(completed_steps)
local_total_tokens = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
local_total_tokens_this_log_period = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
total_token_including_padding = torch.tensor(0, dtype=torch.int64, device=accelerator.device)
start_time = time.time()
skipped_batches = False
@@ -777,7 +833,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
train_dataloader.set_epoch(epoch)
total_loss = 0
total_aux_loss = 0
if last_checkpoint_path and resume_step is not None and not skipped_batches:
if last_checkpoint_path and resume_step and not skipped_batches:
# We skip the first `n` batches in the dataloader when resuming from a checkpoint.
active_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
# Only perform this skip once
@@ -786,20 +842,21 @@ def main(args: FlatArguments, tc: TokenizerConfig):
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
if "attention_mask" in batch:
local_total_tokens += batch["attention_mask"].sum()
tokens_in_batch = batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
elif "position_ids" in batch:
tokens_in_batch = batch["position_ids"].numel()
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
elif "cu_seq_lens_q" in batch:
tokens_in_batch = batch["cu_seq_lens_q"][-1]
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
else:
raise ValueError(
f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
)
local_total_tokens += tokens_in_batch
local_total_tokens_this_log_period += tokens_in_batch

with accelerator.accumulate(model):
if args.load_balancing_loss:
outputs = model(**batch, use_cache=False, output_router_logits=True)
@@ -835,7 +892,20 @@ def main(args: FlatArguments, tc: TokenizerConfig):
loss += aux_loss
# We keep track of the loss at each logged step
total_loss += loss.detach().float()
accelerator.backward(loss)

# [Pre-backwards scalings]
# accelerator.backward internally divides by `gradient_accumulation_steps`, which is
# only the right thing to do for `mean` losses. For "sum" losses, we counteract this
# by multiplying the loss by `gradient_accumulation_steps` before the backwards
# call. Additionally, DeepSpeed/FSDP average the gradients across processes, whereas
# we should be summing gradients for a "sum" loss, hence we also multiply by the
# world size in this latter case.
accelerator.backward(
loss
if args.reduce_loss == "mean"
else loss * args.gradient_accumulation_steps * accelerator.num_processes
)

if args.load_balancing_loss:
total_aux_loss += aux_loss.detach().float()
# clip gradient norm. don't do this with deepspeed
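
# Arithmetic check (standalone, made-up numbers) for the "sum"-loss rescaling in the
# accelerator.backward call above: accelerate divides the loss by gradient_accumulation_steps
# and the distributed backend averages gradients across processes, so pre-multiplying by
# gradient_accumulation_steps * num_processes cancels both and leaves a plain sum over every
# micro-batch on every rank.
G, W = 4, 8                          # hypothetical gradient_accumulation_steps and world size
g = 2.0                              # per-micro-batch gradient of the unscaled sum loss, equal on all ranks
desired = g * G * W                  # a true sum over all micro-batches and ranks
per_micro = (g * G * W) / G          # accelerate's division by G applied to the pre-scaled loss
local = per_micro * G                # accumulated over G micro-batches on one rank
global_grad = sum([local] * W) / W   # the all-reduce mean across W identical ranks
assert global_grad == desired        # 64.0 == 64.0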
@@ -850,23 +920,85 @@ def main(args: FlatArguments, tc: TokenizerConfig):
progress_bar.update(1)
completed_steps += 1
if args.logging_steps and completed_steps % args.logging_steps == 0:
avg_loss = (
accelerator.gather(total_loss).mean().item()
sum_loss = accelerator.gather(total_loss).sum().item()
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_tokens_including_padding = (
accelerator.gather(total_token_including_padding).sum().item()
)
total_tokens_this_log_period = accelerator.gather(local_total_tokens_this_log_period).sum().item()
local_total_tokens_this_log_period.zero_()

avg_tokens_per_batch = (
total_tokens
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ args.logging_steps
/ completed_steps
)
avg_tokens_per_batch_including_padding = (
total_tokens_including_padding
/ accelerator.num_processes
/ args.per_device_train_batch_size
/ args.gradient_accumulation_steps
/ completed_steps
)
total_tokens = accelerator.gather(local_total_tokens).sum().item()
total_tokens_including_padding = accelerator.gather(total_token_including_padding).sum().item()
metrics_to_log = {
"learning_rate": lr_scheduler.get_last_lr()[0],
"train_loss": avg_loss,
"total_tokens": total_tokens,
"per_device_tps": total_tokens / accelerator.num_processes / (time.time() - start_time),
"total_tokens_this_log_period": total_tokens_this_log_period,
"avg_tokens_per_batch": avg_tokens_per_batch,
"avg_tokens_per_batch_including_padding": avg_tokens_per_batch_including_padding,
"per_device_tps": total_tokens
/ accelerator.num_processes
/ (time.time() - start_time),
"total_tokens_including_padding": total_tokens_including_padding,
"per_device_tps_including_padding": total_tokens_including_padding
/ accelerator.num_processes
/ (time.time() - start_time),
"reserved_mem_GiB": torch.cuda.max_memory_reserved(
device=torch.cuda.current_device()
)
/ 2**30,
"allocated_mem_GiB": torch.cuda.max_memory_allocated(
device=torch.cuda.current_device()
)
/ 2**30,
}

# [Loss Reporting]
#
# It is useful to handle loss-reporting for the "mean" and "sum" loss cases
# differently. Cases:
# 1) "mean" loss: `sum_loss` takes individual losses which were *averaged* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We instead want the avg over
# these passes. Report avg_loss = sum_loss / total_fwd_passes, which is roughly independent of
# global batch size.
# 2) "sum" loss: `sum_loss` takes individual losses which were *summed* over the toks in their
# sequence and sums them over all fwd passes in the logging period. We want both the avg over each
# optimizer step (which scales with the global batch size) and the average loss per token (which is
# roughly independent of global batch size). Report avg_sum_loss = sum_loss / total_optim_steps
# and avg_loss = sum_loss / total_tokens_this_log_period. The latter is roughly comparable to what
# we report in the "mean" loss case.
if args.reduce_loss == "mean":
total_fwd_passes = (
args.logging_steps * args.gradient_accumulation_steps * accelerator.num_processes
)
avg_loss = sum_loss / total_fwd_passes
else:
avg_loss = sum_loss / total_tokens_this_log_period
total_optim_steps = args.logging_steps * accelerator.num_processes
avg_sum_loss = sum_loss / total_optim_steps
metrics_to_log["train_sum_loss"] = avg_sum_loss

metrics_to_log["train_loss"] = avg_loss

sec_per_step = (time.time() - start_time) / (completed_steps - resume_step)
steps_remaining = args.max_train_steps - completed_steps
secs_remaining = steps_remaining * sec_per_step
accelerator.print(
f"Approx. time remaining: {timedelta(seconds=secs_remaining)}. {args.max_train_steps=}, {completed_steps=}, {steps_remaining=}"
)

if args.load_balancing_loss:
avg_aux_loss = (
accelerator.gather(total_aux_loss).mean().item()
@@ -881,6 +1013,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
logger.info(
f" Step: {completed_steps}, LR: {lr_scheduler.get_last_lr()[0]}, Loss: {avg_loss}, TPS: {total_tokens / (time.time() - start_time)}"
)
accelerator.print(f"{metrics_to_log=}")
if args.with_tracking:
accelerator.log(
metrics_to_log,