diff --git a/open_instruct/dataset_transformation.py b/open_instruct/dataset_transformation.py
index 46e9651190..f179a3f56e 100644
--- a/open_instruct/dataset_transformation.py
+++ b/open_instruct/dataset_transformation.py
@@ -979,6 +979,42 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize
     return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])
 
 
+def preference_span_search_mask_out(
+    row: Dict[str, Any],
+    tokenizer: PreTrainedTokenizer,
+    max_seq_length: int,
+    chosen_key: str = DEFAULT_CHOSEN_KEY,
+    rejected_key: str = DEFAULT_REJECTED_KEY,
+):
+    """
+    Here we assume each example has a chosen and a rejected field, each of which is a list of messages.
+    Each message is a dict with 'role' and 'content' fields.
+    We assume only the last message differs, and that the prompt is contained in the list of messages.
+    """
+    chosen_messages = row[chosen_key]
+    rejected_messages = row[rejected_key]
+    if len(chosen_messages) == 0:
+        raise ValueError("chosen messages field is empty.")
+    if len(rejected_messages) == 0:
+        raise ValueError("rejected messages field is empty.")
+
+    chosen_encoded = sft_span_seach_mask_out(
+        {DEFAULT_SFT_MESSAGES_KEY: chosen_messages}, tokenizer, max_seq_length
+    )
+    rejected_encoded = sft_span_seach_mask_out(
+        {DEFAULT_SFT_MESSAGES_KEY: rejected_messages}, tokenizer, max_seq_length
+    )
+
+    return {
+        CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
+        CHOSEN_LABELS_KEY: chosen_encoded["labels"],
+        CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
+        REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
+        REJECTED_LABELS_KEY: rejected_encoded["labels"],
+        REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
+    }
+
+
 def rlvr_tokenize_v1(
     row: Dict[str, Any],
     tokenizer: PreTrainedTokenizer,
@@ -1064,6 +1100,7 @@ def rlvr_filter_v1(
     "preference_tokenize_v1": (preference_tokenize_v1, "map"),
     "preference_filter_v1": (preference_filter_v1, "filter"),
     "preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1_2, "map"),
+    "preference_span_search_mask_out": (preference_span_search_mask_out, "map"),
     "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
     "rlvr_tokenize_v1": (rlvr_tokenize_v2, "map"),
     "rlvr_filter_v1": (rlvr_filter_v1, "filter"),
@@ -1130,6 +1167,9 @@ def __post_init__(self):
         if os.path.exists(self.dataset_name) and self.dataset_name.endswith(".jsonl"):
             assert self.dataset_split == "train", "Only train split is supported for local jsonl files."
             self.dataset = load_dataset("json", data_files=self.dataset_name, split=self.dataset_split)
+        elif os.path.exists(self.dataset_name) and self.dataset_name.endswith(".parquet"):
+            assert self.dataset_split == "train", "Only train split is supported for local parquet files."
+            self.dataset = load_dataset("parquet", data_files=self.dataset_name, split=self.dataset_split)
         else:
             # commit hash only works for hf datasets
             self.dataset_commit_hash = get_commit_hash(
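
[Reviewer note, not part of the patch] The new ".parquet" branch mirrors the existing ".jsonl" handling: a local parquet file is loaded with the "parquet" builder and the mandatory "train" split. A minimal sketch of what the branch resolves to (the file path is a placeholder, not taken from the patch):

    from datasets import load_dataset

    # equivalent of passing a local parquet file as dataset_name with the default "train" split
    dataset = load_dataset("parquet", data_files="/path/to/preferences.parquet", split="train")
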
diff --git a/open_instruct/dpo_tune_cache.py b/open_instruct/dpo_tune_cache.py
index a2ed8fabd3..3762bea174 100644
--- a/open_instruct/dpo_tune_cache.py
+++ b/open_instruct/dpo_tune_cache.py
@@ -31,6 +31,7 @@
 except Exception:
     pass
 # isort: on
+import json
 import logging
 import math
 import os
@@ -42,11 +43,13 @@
 from typing import Callable, List, Literal, Optional, Union
 
 import datasets
+from functools import partial
 import torch
 import torch.utils
 import torch.utils.data
 import transformers
 from accelerate import Accelerator, DataLoaderConfiguration
+from accelerate.accelerator import GradientAccumulationPlugin
 from accelerate.logging import get_logger
 from accelerate.utils import InitProcessGroupKwargs, set_seed
 from huggingface_hub import HfApi
@@ -55,6 +58,7 @@
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, get_scheduler
+from transformers.training_args import _convert_str_dict
 
 from open_instruct.dataset_transformation import (
     CHOSEN_INPUT_IDS_KEY,
@@ -71,6 +75,9 @@
     simpo_loss,
     wpo_loss,
 )
+from open_instruct.padding_free_collator import (
+    TensorDataCollatorWithFlatteningDPO
+)
 from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
 from open_instruct.utils import (
     ArgumentParserPlus,
@@ -92,11 +99,22 @@ class FlatArguments:
     """
     Full arguments class for all fine-tuning jobs.
     """
+    # Sometimes users will pass in a `str` repr of a dict on the CLI.
+    # We need to track which fields those can be: each time a new arg
+    # has a dict type, it must be added to this list.
+    # Important: these should be typed with Optional[Union[dict, str, ...]]
+    _VALID_DICT_FIELDS = [
+        "additional_model_arguments",
+    ]
     exp_name: str = os.path.basename(__file__)[: -len(".py")]
     """The name of this experiment"""
     run_name: Optional[str] = None
     """A unique name of this run"""
+    add_seed_and_date_to_exp_name: bool = True
+    """Append the seed and a timestamp to exp_name"""
+    do_not_randomize_output_dir: bool = False
+    """By default the experiment name is appended to the output directory; set this flag to disable that"""
     model_name_or_path: Optional[str] = field(
         default=None,
         metadata={
@@ -132,6 +150,11 @@ class FlatArguments:
         default=None,
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
+    additional_model_arguments: Optional[Union[dict, str]] = field(
+        default_factory=dict, metadata={"help": "A dictionary of additional model args used to construct the model."}
+    )
+    sync_each_batch: bool = False
+    """Optionally sync grads every batch when using grad accumulation.
+    Can significantly reduce memory costs."""
     low_cpu_mem_usage: bool = field(
         default=False,
         metadata={
@@ -323,6 +346,13 @@ class FlatArguments:
     cache_dataset_only: bool = False
     """Immediately exit after caching the dataset"""
 
+    packing: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use packing/padding-free collation via TensorDataCollatorWithFlatteningDPO"
+        },
+    )
+
     # Ai2 specific settings
     try_auto_save_to_beaker: bool = True
     """Whether to try to save the model to Beaker dataset `/output` after training"""
@@ -346,6 +376,17 @@ def __post_init__(self):
             raise ValueError("Cannot provide two dataset selection mechanisms.")
         if self.try_launch_beaker_eval_jobs and not self.push_to_hub:
             raise ValueError("Cannot launch Beaker evaluation jobs without pushing to the Hub.")
+
+        # Parse args that could be a `dict` sent in from the CLI as a string
+        for dict_field in self._VALID_DICT_FIELDS:
+            passed_value = getattr(self, dict_field)
+            # We only want to do this if the str starts with a bracket, indicating a `dict`;
+            # otherwise it's likely a filename (if supported)
+            if isinstance(passed_value, str) and passed_value.startswith("{"):
+                loaded_dict = json.loads(passed_value)
+                # Convert str values to types if applicable
+                loaded_dict = _convert_str_dict(loaded_dict)
+                setattr(self, dict_field, loaded_dict)
 
 
 def get_cache_ref_logprobs(
@@ -368,7 +409,10 @@ def get_cache_ref_logprobs(
     cached_reference_chosen_logps = []
     cached_reference_rejected_logps = []
     with torch.no_grad():
-        for step, batch in tqdm(enumerate(active_dataloader), disable=not accelerator.is_local_main_process):
+        for batch in tqdm(
+            active_dataloader, disable=not accelerator.is_local_main_process,
+            desc=f'Generating reference cache (epoch {epoch})'
+        ):
             if args.use_lora:
                 with accelerator.unwrap_model(model).disable_adapter():
                     reference_chosen_logps, reference_rejected_logps, _ = forward_fn(
@@ -398,10 +442,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     timeout_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=args.timeout))
     dataloader_config = DataLoaderConfiguration(use_seedable_sampler=True)
     accelerator = Accelerator(
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
         dataloader_config=dataloader_config,
         **accelerator_log_kwargs,
         kwargs_handlers=[timeout_kwargs],
+        gradient_accumulation_plugin=GradientAccumulationPlugin(
+            num_steps=args.gradient_accumulation_steps,
+            sync_each_batch=args.sync_each_batch,
+        ),
     )
 
     # ------------------------------------------------------------
@@ -421,8 +468,13 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     # ------------------------------------------------------------
     # Set up runtime variables
-    args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
-    args.output_dir = os.path.join(args.output_dir, args.run_name)
+    if args.add_seed_and_date_to_exp_name:
+        args.exp_name = f"{args.exp_name}__{args.seed}__{int(time.time())}"
+    else:
+        args.exp_name = args.exp_name
+    if not args.do_not_randomize_output_dir:
+        args.output_dir = os.path.join(args.output_dir, args.exp_name)
+    logger.info("using the output directory: %s", args.output_dir)
     args.dataset_local_cache_dir = os.path.abspath(args.dataset_local_cache_dir)
     if is_beaker_job():
         args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache"
@@ -435,7 +487,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             args.hf_entity = HfApi().whoami()["name"]
         args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
         if args.hf_repo_revision is None:
-            args.hf_repo_revision = args.run_name
+            args.hf_repo_revision = args.exp_name
         args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}"
 
     if is_beaker_job():
         beaker_config = maybe_get_beaker_config()
@@ -459,7 +511,7 @@ def main(args: FlatArguments, tc: TokenizerConfig):
             experiment_config,
             init_kwargs={
                 "wandb": {
-                    "name": args.run_name,
+                    "name": args.exp_name,
                     "entity": args.wandb_entity,
                     "tags": [args.exp_name] + get_wandb_tags(),
                 }
@@ -524,11 +576,17 @@ def main(args: FlatArguments, tc: TokenizerConfig):
     # Load pretrained model and tokenizer
     if args.config_name:
         config = AutoConfig.from_pretrained(
-            args.config_name, revision=args.model_revision, trust_remote_code=tc.trust_remote_code
+            args.config_name,
+            revision=args.model_revision,
+            trust_remote_code=tc.trust_remote_code,
+            **args.additional_model_arguments,
         )
     elif args.model_name_or_path:
         config = AutoConfig.from_pretrained(
-            args.model_name_or_path, revision=args.model_revision, trust_remote_code=tc.trust_remote_code
+            args.model_name_or_path,
+            revision=args.model_revision,
+            trust_remote_code=tc.trust_remote_code,
+            **args.additional_model_arguments,
         )
     else:
         raise ValueError(
@@ -555,7 +613,7 @@ def load_model():
                 quantization_config=bnb_config,
                 device_map=device_map,
                 torch_dtype=torch.bfloat16,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
             )
         elif args.use_liger_kernel:
             from liger_kernel.transformers import AutoLigerKernelForCausalLM
@@ -582,7 +640,8 @@ def load_model():
                 config=config,
                 trust_remote_code=tc.trust_remote_code,
                 low_cpu_mem_usage=args.low_cpu_mem_usage,
-                use_flash_attention_2=True if args.use_flash_attn else False,
+                torch_dtype=torch.bfloat16,
+                attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
             )
         else:
             logger.info("Training new model from scratch")
@@ -598,8 +657,10 @@ def load_model():
             # gather deepspeed to get "real" embedding size
             embeddings = model.get_input_embeddings()
             with deepspeed.zero.GatheredParameters(embeddings.weight, modifier_rank=None):
-                if len(tokenizer) > embeddings.weight.shape[0]:
-                    model.resize_token_embeddings(len(tokenizer))
+                embedding_size = embeddings.weight.shape[0]
+
+                if len(tokenizer) > embedding_size:
+                    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
 
     if args.use_lora:
         if args.use_qlora:
@@ -630,10 +691,20 @@ def load_model():
         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
 
     # DataLoaders creation:
+    if args.packing:
+        accelerator.print("Using packing/padding-free collation")
+        collate_fn = TensorDataCollatorWithFlatteningDPO(
+            return_position_ids=True, return_flash_attn_kwargs=True
+        )
+    else:
+        collate_fn = DataCollatorForSeq2SeqDPO(
+            tokenizer=tokenizer, model=model, padding="longest"
+        )
+
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest"),
+        collate_fn=collate_fn,
         batch_size=args.per_device_train_batch_size,
     )
 
@@ -750,6 +821,12 @@ def load_model():
     average_log_prob_loss_types = ["simpo", "dpo_norm"]
     average_log_prob = args.dpo_loss_type in average_log_prob_loss_types
     forward_fn = concatenated_forward if args.concatenated_forward else separate_forward
+    if args.packing:
+        if not args.concatenated_forward:
+            raise NotImplementedError(
+                "separate forward not implemented for packing/padding-free"
+            )
+        forward_fn = partial(forward_fn, packing=True)
"dpo" or args.dpo_loss_type == "dpo_norm": epoch_cached_reference_chosen_logps, epoch_cached_reference_rejected_logps = get_cache_ref_logprobs( model, diff --git a/open_instruct/dpo_utils.py b/open_instruct/dpo_utils.py index 3555b50414..e848b68c44 100644 --- a/open_instruct/dpo_utils.py +++ b/open_instruct/dpo_utils.py @@ -26,6 +26,10 @@ from transformers import DataCollatorForSeq2Seq from open_instruct.model_utils import log_softmax_and_gather +from open_instruct.padding_free_collator import ( + concatenated_inputs as pf_concatenated_inputs, + get_batch_logps as pf_get_batch_logps +) torch.backends.cuda.matmul.allow_tf32 = True @@ -225,29 +229,47 @@ def concatenated_forward( batch: Dict[str, Union[List, torch.LongTensor]], average_log_prob: bool = False, output_router_logits: bool = False, + packing: bool = False, ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. We do this to avoid doing two forward passes, because it's faster for FSDP. """ - concatenated_batch = concatenated_inputs(batch) + if not packing: + concatenated_batch = concatenated_inputs(batch) + else: + concatenated_batch, bs = pf_concatenated_inputs(batch) + + inputs = { + k.replace("concatenated_", ""):v + for k, v in concatenated_batch.items() if k.startswith("concatenated_") and not k.endswith("labels") + + } if output_router_logits: outputs = model( - input_ids=concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], + **inputs, output_router_logits=True, ) logits = outputs.logits.to(torch.float32) aux_loss = outputs.aux_loss else: logits = model( - input_ids=concatenated_batch["concatenated_input_ids"], - attention_mask=concatenated_batch["concatenated_attention_mask"], + **inputs ).logits.to(torch.float32) aux_loss = None - all_logps = _get_batch_logps(logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob) - chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]] - rejected_logps = all_logps[batch["chosen_input_ids"].shape[0] :] + + if not packing: + all_logps = _get_batch_logps(logits, concatenated_batch["concatenated_labels"], average_log_prob=average_log_prob) + bs = batch["chosen_input_ids"].shape[0] + else: + all_logps = pf_get_batch_logps( + logits, + concatenated_batch["concatenated_labels"], + inputs['cu_seq_lens_k'], # assume same as q + average_log_prob=average_log_prob + ) + chosen_logps = all_logps[:bs] + rejected_logps = all_logps[bs:] return chosen_logps, rejected_logps, aux_loss diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 2f9e6b2e93..7f37ef8e60 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -3,6 +3,7 @@ import torch from transformers import DefaultDataCollator +from typing import Dict, Union, List @dataclass @@ -19,24 +20,10 @@ class TensorDataCollatorWithFlattening(DefaultDataCollator): batch size 1, with additional information included in the batch to demarcate example boundaries. 
""" - def __init__( - self, - *args, - return_flash_attn_kwargs=True, - return_position_ids=True, - return_seq_idx=True, - separator_id=-100, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.return_flash_attn_kwargs = return_flash_attn_kwargs - self.return_position_ids = return_position_ids - self.return_seq_idx = return_seq_idx - self.separator_id = separator_id - warnings.warn( - "Using `TensorDataCollatorWithFlattening` will flatten the entire mini batch into a " - "single long sequence. Make sure your attention computation is able to handle it!" - ) + return_flash_attn_kwargs: bool = True + return_position_ids: bool = True + return_seq_idx: bool = True + separator_id: int = -100 def __call__(self, features, return_tensors=None, separator_id=None): if return_tensors is None: @@ -84,3 +71,119 @@ def __call__(self, features, return_tensors=None, separator_id=None): ret["input_ids"] = torch.cat(ret["input_ids"], dim=0)[None] ret["labels"] = torch.cat(ret["labels"], dim=0)[None] return ret + + +@dataclass +class TensorDataCollatorWithFlatteningDPO(TensorDataCollatorWithFlattening): + + def __call__(self, features, return_tensors=None): + + # call the original collator on chosen and rejected separately, then combine + def filter_batch(match_string, features): + return [{k.replace(match_string, ""): v for k, v in f.items() if match_string in k} for f in features] + + chosen_features = super().__call__(filter_batch("chosen_", features), return_tensors=return_tensors) + rejected_features = super().__call__(filter_batch("rejected_", features), return_tensors=return_tensors) + + result = {} + for k in chosen_features: + result["chosen_" + k] = chosen_features[k] + for k in rejected_features: + result["rejected_" + k] = rejected_features[k] + return result + +# - dpo concatenation for padding free +def concatenated_inputs( + batch: Dict[str, Union[List, torch.LongTensor]], + tag: str = 'concatenated_', +) -> Dict[str, torch.LongTensor]: + + chosen_features, rejected_features = {}, {} + for k in batch: + if k.startswith('chosen_'): + chosen_features[k.replace("chosen_", "")] = batch[k] + else: + rejected_features[k.replace("rejected_", "")] = batch[k] + + # - need to return chosen + ret = { + f'{tag}input_ids': torch.cat([chosen_features['input_ids'], rejected_features['input_ids']], axis=-1) + } + if "labels" in chosen_features: + ret[f'{tag}labels'] = torch.cat([chosen_features['labels'], rejected_features['labels']], axis=-1) + + if "cu_seq_lens_q" in chosen_features: + ret[f'{tag}cu_seq_lens_q'] = torch.cat([ + chosen_features['cu_seq_lens_q'], + rejected_features['cu_seq_lens_q'][1:] + chosen_features['cu_seq_lens_q'][-1], + ]) + ret[f'{tag}cu_seq_lens_k'] = torch.cat([ + chosen_features['cu_seq_lens_k'], + rejected_features['cu_seq_lens_k'][1:] + chosen_features['cu_seq_lens_k'][-1], + ]) + ret[f'{tag}max_length_q'] = max( + chosen_features['max_length_q'], + rejected_features['max_length_q'], + ) + ret[f'{tag}max_length_k'] = max( + chosen_features['max_length_k'], + rejected_features['max_length_k'], + ) + + if "position_ids" in chosen_features: + ret[f"{tag}position_ids"] = torch.cat([ + chosen_features['position_ids'], + rejected_features['position_ids'] + ], dim=-1) + + if 'seq_idx' in chosen_features: + ret[f"{tag}seq_idx"] = torch.cat([ + chosen_features['seq_idx'], + rejected_features['seq_idx'] + + chosen_features['seq_idx'][0,-1], + ], dim=-1) + + return ret, len(chosen_features['cu_seq_lens_k']) - 1 + +# for dpo - padding free +def get_batch_logps( + logits: 
+    logits: torch.FloatTensor,
+    labels: torch.LongTensor,
+    cu_seq_lens: torch.LongTensor,
+    average_log_prob: bool = False,
+) -> torch.FloatTensor:
+
+    assert logits.shape[:-1] == labels.shape
+
+    # we will get crossings between labels / logits at packed batch boundaries,
+    # but we assume the loss mask handles those places
+    labels = labels[:, 1:].clone()
+    logits = logits[:, :-1, :]
+    loss_mask = labels != -100
+
+    # dummy token; we'll ignore the losses on these tokens later
+    labels[labels == -100] = 0
+
+    # account for the labels / logits shift above
+    cu_seq_lens = cu_seq_lens.clone() - 1
+    cu_seq_lens[0] = 0
+
+    splits = cu_seq_lens.diff().tolist()
+    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+    return torch.concat(
+        [
+            (
+                (ps * mask).sum(-1) / mask.sum(-1)
+                if average_log_prob else
+                (ps * mask).sum(-1)
+            )
+            for ps, mask in zip(
+                torch.split(per_token_logps, splits, dim=-1),
+                torch.split(loss_mask, splits, dim=-1),
+            )
+        ]
+    )
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index 13354ab2e7..ec12b0ef78 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -261,6 +261,8 @@ def get_datasets(
         # if dataset ends with .json or .jsonl, load from file
         if ds.endswith(".json") or ds.endswith(".jsonl"):
            dataset = load_dataset("json", data_files=ds, split=split)
+        elif ds.endswith(".parquet"):
+            dataset = load_dataset("parquet", data_files=ds, split=split)
         else:
             try:
                 # Try first if dataset on a Hub repo
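
[Reviewer note, not part of the patch] A small sanity sketch of the padding-free `get_batch_logps` added above: two sequences of lengths 3 and 2 packed into a single row, with -100 marking masked positions. The tensors are hand-built illustrative assumptions; in training they would come from TensorDataCollatorWithFlatteningDPO.

    import torch

    from open_instruct.padding_free_collator import get_batch_logps

    vocab_size = 7
    logits = torch.randn(1, 5, vocab_size)          # one packed row of 5 tokens
    labels = torch.tensor([[-100, 4, 2, -100, 6]])  # -100 = ignored (prompt/boundary) tokens
    cu_seq_lens = torch.tensor([0, 3, 5])           # flash-attn style cumulative sequence lengths

    logps = get_batch_logps(logits, labels, cu_seq_lens, average_log_prob=False)
    print(logps.shape)  # torch.Size([2]) -- one summed log-prob per packed sequence
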