40 changes: 40 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -979,6 +979,42 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize
return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])


def preference_span_search_mask_out(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
max_seq_length: int,
chosen_key: str = DEFAULT_CHOSEN_KEY,
rejected_key: str = DEFAULT_REJECTED_KEY,
):
"""
    Here we assume each example has `chosen` and `rejected` fields, both of which are lists of messages.
    Each message is a dict with 'role' and 'content' fields.
    We assume only the last message differs, and that the prompt is contained in the list of messages.
"""
chosen_messages = row[chosen_key]
rejected_messages = row[rejected_key]
if len(chosen_messages) == 0:
raise ValueError("chosen messages field is empty.")
if len(rejected_messages) == 0:
raise ValueError("rejected messages field is empty.")

chosen_encoded = sft_span_seach_mask_out(
{DEFAULT_SFT_MESSAGES_KEY: chosen_messages}, tokenizer, max_seq_length
)
rejected_encoded = sft_span_seach_mask_out(
{DEFAULT_SFT_MESSAGES_KEY: rejected_messages}, tokenizer, max_seq_length
)

return {
CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
CHOSEN_LABELS_KEY: chosen_encoded["labels"],
CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
REJECTED_LABELS_KEY: rejected_encoded["labels"],
REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
}
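# Illustrative input row for preference_span_search_mask_out (a sketch only; assumes
# DEFAULT_CHOSEN_KEY / DEFAULT_REJECTED_KEY map to "chosen" / "rejected"). Both fields
# share the same prompt and differ only in the final assistant turn:
#   {
#       "chosen": [
#           {"role": "user", "content": "What is 2 + 2?"},
#           {"role": "assistant", "content": "2 + 2 = 4."},
#       ],
#       "rejected": [
#           {"role": "user", "content": "What is 2 + 2?"},
#           {"role": "assistant", "content": "5."},
#       ],
#   }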


def rlvr_tokenize_v1(
row: Dict[str, Any],
tokenizer: PreTrainedTokenizer,
@@ -1064,6 +1100,7 @@ def rlvr_filter_v1(
"preference_tokenize_v1": (preference_tokenize_v1, "map"),
"preference_filter_v1": (preference_filter_v1, "filter"),
"preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1_2, "map"),
"preference_span_search_mask_out": (preference_span_search_mask_out, "map"),
"preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
"rlvr_tokenize_v1": (rlvr_tokenize_v2, "map"),
"rlvr_filter_v1": (rlvr_filter_v1, "filter"),
Expand Down Expand Up @@ -1130,6 +1167,9 @@ def __post_init__(self):
if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
self.dataset = load_dataset("json", data_files=self.dataset_name, split=self.dataset_split)
elif os.path.exists(self.dataset_name) and self.dataset_name.endswith(".parquet"):
assert self.dataset_split == "train", "Only train split is supported for local parquet files."
self.dataset = load_dataset("parquet", data_files=self.dataset_name, split=self.dataset_split)
else:
# commit hash only works for hf datasets
self.dataset_commit_hash = get_commit_hash(
103 changes: 90 additions & 13 deletions open_instruct/dpo_tune_cache.py
@@ -31,6 +31,7 @@
except Exception:
pass
# isort: on
import json
import logging
import math
import os
@@ -42,11 +43,13 @@
from typing import Callable, List, Literal, Optional, Union

import datasets
from functools import partial
import torch
import torch.utils
import torch.utils.data
import transformers
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.accelerator import GradientAccumulationPlugin
from accelerate.logging import get_logger
from accelerate.utils import InitProcessGroupKwargs, set_seed
from huggingface_hub import HfApi
@@ -55,6 +58,7 @@
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, get_scheduler
from transformers.training_args import _convert_str_dict

from open_instruct.dataset_transformation import (
CHOSEN_INPUT_IDS_KEY,
@@ -71,6 +75,9 @@
simpo_loss,
wpo_loss,
)
from open_instruct.padding_free_collator import (
TensorDataCollatorWithFlatteningDPO
)
from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
from open_instruct.utils import (
ArgumentParserPlus,
@@ -92,11 +99,22 @@ class FlatArguments:
"""
Full arguments class for all fine-tuning jobs.
"""
    # Sometimes users will pass in a `str` repr of a dict on the CLI.
    # We need to track which fields those can be: each time a new arg
    # has a dict type, it must be added to this list.
    # Important: these should be typed as Optional[Union[dict, str, ...]].
_VALID_DICT_FIELDS = [
"additional_model_arguments",
]
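    # Illustrative only (hypothetical value): passing
    #   --additional_model_arguments '{"rope_theta": 500000}'
    # on the CLI arrives here as a str; __post_init__ below json.loads it and runs
    # _convert_str_dict so string values are coerced to Python types before the dict is
    # splatted into AutoConfig.from_pretrained.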

exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""The name of this experiment"""
run_name: Optional[str] = None
"""A unique name of this run"""
add_seed_and_date_to_exp_name: bool = True
"""Append the seed and date to exp_name"""
do_not_randomize_output_dir: bool = False
"""By default the output directory will be randomized"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
@@ -132,6 +150,11 @@ class FlatArguments:
default=None,
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
additional_model_arguments: Optional[Union[dict, str]] = field(
default_factory=dict, metadata={"help": "A dictionary of additional model args used to construct the model."}
)
sync_each_batch: bool = False
"""Optionaly sync grads every batch when using grad accumulation. Can significantly reduce memory costs."""
low_cpu_mem_usage: bool = field(
default=False,
metadata={
@@ -323,6 +346,13 @@ class FlatArguments:
cache_dataset_only: bool = False
"""Immediately exit after caching the dataset"""

packing: bool = field(
default=False,
metadata={
"help": "Whether to use packing/padding-free collation via DataCollatorWithFlatteningDPO"
},
)

# Ai2 specific settings
try_auto_save_to_beaker: bool = True
"""Whether to try to save the model to Beaker dataset `/output` after training"""
@@ -346,6 +376,17 @@ def __post_init__(self):
raise ValueError("Cannot provide two dataset selection mechanisms.")
if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")

        # Parse args that could be a `dict` passed from the CLI as a string
        for dict_field in self._VALID_DICT_FIELDS:
            passed_value = getattr(self, dict_field)
            # We only want to do this if the str starts with a bracket to indicate a `dict`,
            # else it's likely a filename (if supported).
            if isinstance(passed_value, str) and passed_value.startswith("{"):
                loaded_dict = json.loads(passed_value)
                # Convert str values to their proper types where applicable
                loaded_dict = _convert_str_dict(loaded_dict)
                setattr(self, dict_field, loaded_dict)


def get_cache_ref_logprobs(
@@ -368,7 +409,10 @@ def get_cache_ref_logprobs(
cached_reference_chosen_logps = []
cached_reference_rejected_logps = []
with torch.no_grad():
for step, batch in tqdm(enumerate(active_dataloader), disable=not accelerator.is_local_main_process):
for batch in tqdm(
active_dataloader, disable=not accelerator.is_local_main_process,
            desc=f"Generating reference cache (epoch {epoch})",
):
if args.use_lora:
with accelerator.unwrap_model(model).disable_adapter():
reference_chosen_logps, reference_rejected_logps, _ = forward_fn(
@@ -398,10 +442,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):
timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
dataloader_config=dataloader_config,
**accelerator_log_kwargs,
kwargs_handlers=[timeout_kwargs],
gradient_accumulation_plugin=GradientAccumulationPlugin(
num_steps=args.gradient_accumulation_steps,
sync_each_batch=args.sync_each_batch,
),
)

# ------------------------------------------------------------
@@ -421,8 +468,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):

# ------------------------------------------------------------
# Set up runtime variables
args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
args.output_dir = os.path.join(args.output_dir, args.run_name)
    if args.add_seed_and_date_to_exp_name:
        args.exp_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
if not args.do_not_randomize_output_dir:
args.output_dir = os.path.join(args.output_dir, args.exp_name)
logger.info("using the output directory: %s", args.output_dir)
args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
if is_beaker_job():
args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache"
@@ -435,7 +487,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
args.hf_entity = HfApi().whoami()["name"]
args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
if args.hf_repo_revision is None:
args.hf_repo_revision = args.run_name
args.hf_repo_revision = args.exp_name
args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
if is_beaker_job():
beaker_config = maybe_get_beaker_config()
@@ -459,7 +511,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
experiment_config,
init_kwargs={
"wandb": {
"name": args.run_name,
"name": args.exp_name,
"entity": args.wandb_entity,
"tags": [args.exp_name] + get_wandb_tags(),
}
@@ -524,11 +576,17 @@ def main(args: FlatArguments, tc: TokenizerConfig):
# Load pretrained model and tokenizer
if args.config_name:
config = AutoConfig.from_pretrained(
args.config_name, revision=args.model_revision, trust_remote_code=tc.trust_remote_code
args.config_name,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
elif args.model_name_or_path:
config = AutoConfig.from_pretrained(
args.model_name_or_path, revision=args.model_revision, trust_remote_code=tc.trust_remote_code
args.model_name_or_path,
revision=args.model_revision,
trust_remote_code=tc.trust_remote_code,
**args.additional_model_arguments,
)
else:
raise ValueError(
@@ -555,7 +613,7 @@ def load_model():
quantization_config=bnb_config,
device_map=device_map,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True if args.use_flash_attn else False,
attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
)
elif args.use_liger_kernel:
from liger_kernel.transformers import AutoLigerKernelForCausalLM
@@ -582,7 +640,8 @@ def load_model():
config=config,
trust_remote_code=tc.trust_remote_code,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_flash_attention_2=True if args.use_flash_attn else False,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
)
else:
logger.info("Training new model from scratch")
@@ -598,8 +657,10 @@ def load_model():
# gather deepspeed to get "real" embedding size
embeddings = model.get_input_embeddings()
with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
if len(tokenizer) > embeddings.weight.shape[0]:
model.resize_token_embeddings(len(tokenizer))
embedding_size = embeddings.weight.shape[0]

if len(tokenizer) > embedding_size:
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
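        # pad_to_multiple_of=8 is presumably an efficiency choice: it keeps the resized
        # embedding/vocab dimension a multiple of 8, which is friendlier to tensor cores.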

if args.use_lora:
if args.use_qlora:
@@ -630,10 +691,20 @@ def load_model():
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

# DataLoaders creation:
if args.packing:
accelerator.print("Using packing/padding-free collation")
collate_fn = TensorDataCollatorWithFlatteningDPO(
return_position_ids=True, return_flash_attn_kwargs=True
)
else:
collate_fn = DataCollatorForSeq2SeqDPO(
tokenizer=tokenizer, model=model, padding="longest"
)
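    # A sketch of the intent here (inferred from the collator arguments): with packing the
    # chosen/rejected sequences are flattened into one long row instead of being padded to the
    # batch max length, and position_ids plus flash-attention cu_seqlens-style kwargs are
    # emitted so attention stays confined to each original example.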

train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest"),
collate_fn=collate_fn,
batch_size=args.per_device_train_batch_size,
)

@@ -750,6 +821,12 @@ def load_model():
average_log_prob_loss_types = ["simpo", "dpo_norm"]
average_log_prob = args.dpo_loss_type in average_log_prob_loss_types
forward_fn = concatenated_forward if args.concatenated_forward else separate_forward
if args.packing:
if not args.concatenated_forward:
raise NotImplementedError(
"seperate forward not implemented for packing/padding-free"
)
forward_fn = partial(forward_fn, packing=True)
if args.dpo_loss_type == "dpo" or args.dpo_loss_type == "dpo_norm":
epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps = get_cache_ref_logprobs(
model,
38 changes: 30 additions & 8 deletions open_instruct/dpo_utils.py
@@ -26,6 +26,10 @@
from transformers import DataCollatorForSeq2Seq

from open_instruct.model_utils import log_softmax_and_gather
from open_instruct.padding_free_collator import (
concatenated_inputs as pf_concatenated_inputs,
get_batch_logps as pf_get_batch_logps
)

torch.backends.cuda.matmul.allow_tf32 = True

@@ -225,29 +229,47 @@ def concatenated_forward(
batch: Dict[str, Union[List, torch.LongTensor]],
average_log_prob: bool = False,
output_router_logits: bool = False,
packing: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

We do this to avoid doing two forward passes, because it's faster for FSDP.
"""
concatenated_batch = concatenated_inputs(batch)
if not packing:
concatenated_batch = concatenated_inputs(batch)
else:
concatenated_batch, bs = pf_concatenated_inputs(batch)

    inputs = {
        k.replace("concatenated_", ""): v
        for k, v in concatenated_batch.items()
        if k.startswith("concatenated_") and not k.endswith("labels")
    }
if output_router_logits:
outputs = model(
input_ids=concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
**inputs,
output_router_logits=True,
)
logits = outputs.logits.to(torch.float32)
aux_loss = outputs.aux_loss
else:
logits = model(
input_ids=concatenated_batch["concatenated_input_ids"],
attention_mask=concatenated_batch["concatenated_attention_mask"],
**inputs
).logits.to(torch.float32)
aux_loss = None
all_logps = _get_batch_logps(logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob)
chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]]
rejected_logps = all_logps[batch["chosen_input_ids"].shape[0] :]

if not packing:
all_logps = _get_batch_logps(logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob)
bs = batch["chosen_input_ids"].shape[0]
else:
all_logps = pf_get_batch_logps(
logits,
concatenated_batch["concatenated_labels"],
inputs['cu_seq_lens_k'], # assume same as q
average_log_prob=average_log_prob
)
chosen_logps = all_logps[:bs]
rejected_logps = all_logps[bs:]
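    # Assumed contract of the padding-free helpers: pf_concatenated_inputs packs the chosen
    # sequences first, then the rejected ones, and returns the chosen count as `bs`;
    # pf_get_batch_logps uses the cumulative sequence lengths (cu_seq_lens_k, assumed equal to
    # the query lengths) to reduce per-sequence log-probs, so the first `bs` entries of
    # all_logps are chosen and the remainder rejected.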
return chosen_logps, rejected_logps, aux_loss

