From ebeefe1882bed38a373788e5ed1f463ece81812c Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Wed, 25 Jun 2025 14:55:25 +0000 Subject: [PATCH 1/4] padding-free --- open_instruct/finetune.py | 32 ++++- open_instruct/padding_free_collator.py | 87 +++++++++++ tests/test_padding_free.py | 191 +++++++++++++++++++++++++ 3 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 open_instruct/padding_free_collator.py create mode 100644 tests/test_padding_free.py diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index f1a8c3c15f..44338d2a73 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -60,6 +60,7 @@ visualize_token, ) from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate +from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening from open_instruct.utils import ( ArgumentParserPlus, clean_last_n_checkpoints, @@ -325,6 +326,10 @@ class FlatArguments: sync_each_batch: bool = False """Optionaly sync grads every batch when using grad accumulation. Can significantly reduce memory costs.""" + padding_free: bool = field( + default=False, + metadata={"help": "Whether to use padding-free collation via TensorDataCollatorWithFlattening"}, + ) def __post_init__(self): if self.reduce_loss not in ["mean", "sum"]: @@ -607,10 +612,18 @@ def main(args: FlatArguments, tc: TokenizerConfig): model.gradient_checkpointing_enable() # DataLoaders creation: + if args.padding_free: + collate_fn = TensorDataCollatorWithFlattening() + else: + collate_fn = DataCollatorForSeq2Seq( + tokenizer=tokenizer, model=model, padding="longest" + ) + + accelerator.print("Creating dataloader") train_dataloader = DataLoader( train_dataset, shuffle=True, - collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"), + collate_fn=collate_fn, batch_size=args.per_device_train_batch_size, ) @@ -743,8 +756,21 @@ def main(args: FlatArguments, tc: TokenizerConfig): else: active_dataloader = train_dataloader for step, batch in enumerate(active_dataloader): - local_total_tokens += batch["attention_mask"].sum() - total_token_including_padding += batch["attention_mask"].numel() + if "attention_mask" in batch: + local_total_tokens += batch["attention_mask"].sum() + total_token_including_padding += batch["attention_mask"].numel() + elif "position_ids" in batch: + tokens_in_batch = batch["position_ids"].numel() + local_total_tokens += tokens_in_batch + total_token_including_padding += tokens_in_batch + elif "cu_seq_lens_q" in batch: + tokens_in_batch = batch["cu_seq_lens_q"][-1] + local_total_tokens += tokens_in_batch + total_token_including_padding += tokens_in_batch + else: + raise ValueError( + f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}" + ) with accelerator.accumulate(model): if args.load_balancing_loss: outputs = model(**batch, use_cache=False, output_router_logits=True) diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py new file mode 100644 index 0000000000..56b2d93650 --- /dev/null +++ b/open_instruct/padding_free_collator.py @@ -0,0 +1,87 @@ +import warnings +from dataclasses import dataclass + +import torch +from transformers import DefaultDataCollator + + +@dataclass +class TensorDataCollatorWithFlattening(DefaultDataCollator): + """ + Data collator used for padding free approach. 
Does the following: + + - concatate the entire mini batch into single long sequence [1, total_tokens] + - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100 + - no padding will be added, returns `input_ids`, `labels` and `position_ids` + """ + + def __init__( + self, + *args, + return_flash_attn_kwargs=True, + return_position_ids=True, + return_seq_idx=True, + separator_id=-100, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.return_flash_attn_kwargs = return_flash_attn_kwargs + self.return_position_ids = return_position_ids + self.return_seq_idx = return_seq_idx + self.separator_id = separator_id + warnings.warn( + "Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence." + "Make sure your attention computation is able to handle it!" + ) + + def __call__(self, features, return_tensors=None, separator_id=None): + if return_tensors is None: + return_tensors = self.return_tensors + if separator_id is None: + separator_id = self.separator_id + assert self.return_flash_attn_kwargs, ( + "Only should be used with return_flash_attn_kwargs=True" + ) + if self.return_flash_attn_kwargs: + cu_seq_lens = [0] + max_length = 0 + if self.return_position_ids: + pos_ids = [] + if self.return_seq_idx: + seq_idx = [] + is_labels_provided = "labels" in features[0] + ret = {"input_ids": [], "labels": []} + separator = torch.tensor( + [separator_id], + dtype=features[0]["input_ids"].dtype, + device=features[0]["input_ids"].device, + ) + for s_idx, item in enumerate(features): + input_ids = item["input_ids"] + ret["input_ids"].append(input_ids) + if is_labels_provided: + ret["labels"].append(separator) + ret["labels"].append(item["labels"][1:]) + else: + ret["labels"].append(separator) + ret["labels"].append(input_ids[1:]) + if self.return_flash_attn_kwargs: + cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids)) + max_length = max(max_length, len(input_ids)) + if self.return_position_ids: + pos_ids.append(torch.arange(input_ids.numel(), device=input_ids.device)) + if self.return_seq_idx: + seq_idx.append(torch.full_like(input_ids, s_idx, dtype=torch.int32)) + + if self.return_flash_attn_kwargs: + ret["cu_seq_lens_q"] = ret["cu_seq_lens_k"] = torch.tensor( + cu_seq_lens, dtype=torch.int32, device=features[0]["input_ids"].device + ) + ret["max_length_q"] = ret["max_length_k"] = max_length + if self.return_position_ids: + ret["position_ids"] = torch.cat(pos_ids, dim=0)[None] + if self.return_seq_idx: + ret["seq_idx"] = torch.cat(seq_idx, dim=0)[None] + ret["input_ids"] = torch.cat(ret["input_ids"], dim=0)[None] + ret["labels"] = torch.cat(ret["labels"], dim=0)[None] + return ret diff --git a/tests/test_padding_free.py b/tests/test_padding_free.py new file mode 100644 index 0000000000..7333d5b033 --- /dev/null +++ b/tests/test_padding_free.py @@ -0,0 +1,191 @@ +import sys +from copy import deepcopy +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + BambaConfig, + BambaForCausalLM, + DataCollatorForSeq2Seq, + LlamaConfig, + LlamaForCausalLM, +) + +# HACK for being able to load the collator without needing to install open-instruct +open_instruct_dir = Path(__file__).parent.parent.absolute() +sys.path.append(open_instruct_dir) +from open_instruct.dataset_processor import CHAT_TEMPLATES +from open_instruct.dataset_transformation import sft_tulu_tokenize_and_truncate_v1 +from 
open_instruct.padding_free_collator import TensorDataCollatorWithFlattening + +MODEL_CLASSES = {"bamba": BambaForCausalLM, "llama": LlamaForCausalLM} +MODEL_CFGS = { + "bamba": BambaConfig, + "llama": LlamaConfig, +} +MODEL_KWARGS = { + "bamba": dict( + attention_dropout=0.0, + attn_layer_indices=None, + attn_rotary_emb=8, + hidden_act="silu", + hidden_size=32, + initializer_range=0.02, + intermediate_size=64, + mamba_chunk_size=16, + mamba_d_conv=4, + mamba_d_state=16, + mamba_expand=2, + mamba_n_groups=1, + mamba_n_heads=16, + max_position_embeddings=512, + num_attention_heads=4, + num_hidden_layers=1, + num_key_value_heads=2, + pad_token_id=0, + ), + "llama": dict( + hidden_act="gelu", + hidden_size=32, + intermediate_size=64, + is_training=True, + max_position_embeddings=512, + mlp_bias=False, + num_attention_heads=2, + num_hidden_layers=1, + num_key_value_heads=2, + ), +} + + +class TestPaddingFree: + seqlen = 128 + batch_size = 2 + dtype = torch.bfloat16 + + def get_fa2_model_and_cfg(self, model_name: str, vocab_size: int) -> nn.Module: + model_cls = MODEL_CLASSES[model_name] + model_cfg = MODEL_CFGS[model_name] + model_kwargs = MODEL_KWARGS[model_name] + cfg = model_cfg( + **{ + **model_kwargs, + "torch_dtype": self.dtype, + "attn_implementation": "flash_attention_2", + "vocab_size": vocab_size, + } + ) + model = model_cls(cfg).to("cuda", dtype=self.dtype) + return model, cfg + + @pytest.mark.parametrize("model_name", ["bamba", "llama"]) + @pytest.mark.parametrize("loss_type", ["mean", "sum"]) + def test_padding_free(self, model_name: str, loss_type: str) -> None: + torch.manual_seed(42) + + tokenizer = AutoTokenizer.from_pretrained("ibm-ai-platform/Bamba-9B-v2") + tokenizer.add_special_tokens({"pad_token": ""}) + tokenizer.chat_template = CHAT_TEMPLATES["tulu"] + vocab_size = len(tokenizer) + + model, cfg = self.get_fa2_model_and_cfg(model_name, vocab_size) + model.initialize_weights() + pf_model = deepcopy(model) + + inputs = torch.randint(cfg.vocab_size, size=(self.batch_size, self.seqlen), device="cpu") + + data = { + 0: { + "messages": [ + {"role": "user", "content": "Why did the chicken cross the road?"}, + {"role": "assistant", "content": "To get to the other side"}, + ] + }, + 1: { + "messages": [ + {"role": "user", "content": "What is one plus two?"}, + {"role": "assistant", "content": "The answer is 3"}, + ] + }, + } + + tok_data = {k: sft_tulu_tokenize_and_truncate_v1(v, tokenizer, max_seq_length=2**30) for k, v in data.items()} + for v in tok_data.values(): + del v["messages"] + + collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest") + dataloader = DataLoader( + tok_data, + shuffle=False, + collate_fn=collate_fn, + batch_size=self.batch_size, + ) + + pf_collate_fn = TensorDataCollatorWithFlattening() + pf_dataloader = DataLoader( + tok_data, + shuffle=False, + collate_fn=pf_collate_fn, + batch_size=self.batch_size, + ) + + batch = next(iter(dataloader)) + pf_batch = next(iter(pf_dataloader)) + for b in (batch, pf_batch): + for k in b: + if torch.is_tensor(b[k]): + b[k] = b[k].cuda() + + assert batch["input_ids"].shape[0] == 2 + assert pf_batch["input_ids"].shape[0] == 1 + + # Also create a batch with the pf style concatenation, but without the pf seq markers as a + # control. Passing this through the model should give incorrect results. 
+ + incorrect_pf_batch = { + "input_ids": pf_batch["input_ids"], + "labels": pf_batch["labels"], + "attention_mask": torch.ones_like(pf_batch["input_ids"]), + } + + outputs = model(**batch) + pf_outputs = pf_model(**pf_batch) + with torch.no_grad(): + incorrect_pf_outputs = model(**incorrect_pf_batch) + + # Compare logits (properly reshaped and masked) + logits = outputs.logits.reshape(1, -1, outputs.logits.shape[-1]) + non_masked_logits = logits[:, batch["attention_mask"].flatten().bool()] + pf_logits = pf_outputs.logits + incorrect_pf_logits = incorrect_pf_outputs.logits + torch.testing.assert_close(pf_logits, non_masked_logits) + with pytest.raises(AssertionError, match="Mismatched elements:"): + torch.testing.assert_close(pf_logits, incorrect_pf_logits) + + if loss_type == "mean": + loss = outputs.loss + pf_loss = pf_outputs.loss + else: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch["labels"].view(-1).long(), reduce="sum") + pf_loss = F.cross_entropy( + pf_logits.view(-1, pf_logits.size(-1)), pf_batch["labels"].view(-1).long(), reduce="sum" + ) + torch.testing.assert_close(loss, pf_loss) + + loss.backward() + pf_loss.backward() + + grads = {n: p.grad for n, p in model.named_parameters()} + pf_grads = {n: p.grad for n, p in pf_model.named_parameters()} + non_nan_grads = set() + nan_grads = set() + for k, g in grads.items(): + torch.testing.assert_close(g, pf_grads[k]) + non_nan_grads.add(k) + print(f"{non_nan_grads=}") + print(f"{nan_grads=}") From 489b0a528ce880a1166fc3f1a338cd74669bab80 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Thu, 3 Jul 2025 19:10:51 +0000 Subject: [PATCH 2/4] cleanup --- open_instruct/padding_free_collator.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 56b2d93650..2f9e6b2e93 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -8,11 +8,15 @@ @dataclass class TensorDataCollatorWithFlattening(DefaultDataCollator): """ - Data collator used for padding free approach. Does the following: + Data collator for padding-free training along the lines of https://huggingface.co/blog/packing-with-FA2 - - concatate the entire mini batch into single long sequence [1, total_tokens] - - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100 - - no padding will be added, returns `input_ids`, `labels` and `position_ids` + Eliminates use of padding which is generically needed for per_gpu_batch_size > 1, thereby + reducing memory costs and increasing throughput. Your model class must support padding-free + training to use this collator correctly. Examples which support padding free include + LlamaForCausalLM and BambaForCausalLM. + + The `input_ids` and `labels` from separate examples are concatenated together into a tensor of + batch size 1, with additional information included in the batch to demarcate example boundaries. """ def __init__( @@ -30,8 +34,8 @@ def __init__( self.return_seq_idx = return_seq_idx self.separator_id = separator_id warnings.warn( - "Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence." - "Make sure your attention computation is able to handle it!" + "Using `TensorDataCollatorWithFlattening` will flatten the entire mini batch into a " + "single long sequence. Make sure your attention computation is able to handle it!" 
) def __call__(self, features, return_tensors=None, separator_id=None): @@ -39,9 +43,6 @@ def __call__(self, features, return_tensors=None, separator_id=None): return_tensors = self.return_tensors if separator_id is None: separator_id = self.separator_id - assert self.return_flash_attn_kwargs, ( - "Only should be used with return_flash_attn_kwargs=True" - ) if self.return_flash_attn_kwargs: cu_seq_lens = [0] max_length = 0 @@ -52,9 +53,7 @@ def __call__(self, features, return_tensors=None, separator_id=None): is_labels_provided = "labels" in features[0] ret = {"input_ids": [], "labels": []} separator = torch.tensor( - [separator_id], - dtype=features[0]["input_ids"].dtype, - device=features[0]["input_ids"].device, + [separator_id], dtype=features[0]["input_ids"].dtype, device=features[0]["input_ids"].device ) for s_idx, item in enumerate(features): input_ids = item["input_ids"] From 312241a556d70431288c0d1f8f5190e6c4c3aaaa Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Mon, 7 Jul 2025 15:08:42 +0000 Subject: [PATCH 3/4] test skips --- tests/test_padding_free.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/test_padding_free.py b/tests/test_padding_free.py index 7333d5b033..d735b6cc91 100644 --- a/tests/test_padding_free.py +++ b/tests/test_padding_free.py @@ -23,11 +23,23 @@ from open_instruct.dataset_transformation import sft_tulu_tokenize_and_truncate_v1 from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening +try: + import mamba_ssm # noqa + import causal_conv1d # noqa + + mamba_and_causal_conv_available = True +except ImportError: + mamba_and_causal_conv_available = False + +try: + import flash_attn # noqa + + flash_attn_available = True +except ImportError: + flash_attn_available = False + MODEL_CLASSES = {"bamba": BambaForCausalLM, "llama": LlamaForCausalLM} -MODEL_CFGS = { - "bamba": BambaConfig, - "llama": LlamaConfig, -} +MODEL_CFGS = {"bamba": BambaConfig, "llama": LlamaConfig} MODEL_KWARGS = { "bamba": dict( attention_dropout=0.0, @@ -83,9 +95,13 @@ def get_fa2_model_and_cfg(self, model_name: str, vocab_size: int) -> nn.Module: model = model_cls(cfg).to("cuda", dtype=self.dtype) return model, cfg + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Padding free tests require CUDA") + @pytest.mark.skipif(not flash_attn_available, reason="Padding free requires flash_attn") @pytest.mark.parametrize("model_name", ["bamba", "llama"]) @pytest.mark.parametrize("loss_type", ["mean", "sum"]) def test_padding_free(self, model_name: str, loss_type: str) -> None: + if model_name == "bamba" and not mamba_and_causal_conv_available: + pytest.skip("bamba padding-free tests require mamba_ssm and causal_conv1d") torch.manual_seed(42) tokenizer = AutoTokenizer.from_pretrained("ibm-ai-platform/Bamba-9B-v2") @@ -119,20 +135,10 @@ def test_padding_free(self, model_name: str, loss_type: str) -> None: del v["messages"] collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest") - dataloader = DataLoader( - tok_data, - shuffle=False, - collate_fn=collate_fn, - batch_size=self.batch_size, - ) + dataloader = DataLoader(tok_data, shuffle=False, collate_fn=collate_fn, batch_size=self.batch_size) pf_collate_fn = TensorDataCollatorWithFlattening() - pf_dataloader = DataLoader( - tok_data, - shuffle=False, - collate_fn=pf_collate_fn, - batch_size=self.batch_size, - ) + pf_dataloader = DataLoader(tok_data, shuffle=False, collate_fn=pf_collate_fn, batch_size=self.batch_size) batch = 
next(iter(dataloader)) pf_batch = next(iter(pf_dataloader)) From 8d7e330b4873f4282465b01e96f8af3e3d232090 Mon Sep 17 00:00:00 2001 From: Garrett Goon Date: Tue, 8 Jul 2025 00:27:38 +0000 Subject: [PATCH 4/4] padding_free -> packing --- open_instruct/finetune.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 44338d2a73..279d146949 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -326,9 +326,9 @@ class FlatArguments: sync_each_batch: bool = False """Optionaly sync grads every batch when using grad accumulation. Can significantly reduce memory costs.""" - padding_free: bool = field( + packing: bool = field( default=False, - metadata={"help": "Whether to use padding-free collation via TensorDataCollatorWithFlattening"}, + metadata={"help": "Whether to use packing/padding-free collation via TensorDataCollatorWithFlattening"}, ) def __post_init__(self): @@ -612,7 +612,7 @@ def main(args: FlatArguments, tc: TokenizerConfig): model.gradient_checkpointing_enable() # DataLoaders creation: - if args.padding_free: + if args.packing: collate_fn = TensorDataCollatorWithFlattening() else: collate_fn = DataCollatorForSeq2Seq(
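
Usage note for the series: after PATCH 4, the padding-free path in open_instruct/finetune.py is selected by the boolean `packing` field of FlatArguments (exposing it as a `--packing` flag is an assumption about how ArgumentParserPlus handles boolean fields); left at its default of False, the existing DataCollatorForSeq2Seq "longest"-padding path is used. The collator can also be exercised standalone. The snippet below is a minimal illustrative sketch, not taken from the patches themselves: it assumes a plain list of tokenized examples holding 1-D `input_ids`/`labels` tensors, mirroring the fixtures in tests/test_padding_free.py.

    # Minimal sketch (hypothetical toy data): flattening two variable-length examples
    # into a single padding-free batch with the collator added in PATCH 1.
    import torch
    from torch.utils.data import DataLoader

    from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

    examples = [
        {"input_ids": torch.tensor([5, 6, 7, 8]), "labels": torch.tensor([-100, 6, 7, 8])},
        {"input_ids": torch.tensor([9, 10]), "labels": torch.tensor([-100, 10])},
    ]
    loader = DataLoader(examples, batch_size=2, shuffle=False,
                        collate_fn=TensorDataCollatorWithFlattening())
    batch = next(iter(loader))

    # input_ids and labels come back flattened with shape [1, 6]; no pad tokens are added.
    # cu_seq_lens_q / cu_seq_lens_k == tensor([0, 4, 6], dtype=int32) mark the example
    # boundaries for FlashAttention, max_length_q / max_length_k == 4, position_ids
    # restart at 0 for each example ([[0, 1, 2, 3, 0, 1]]), seq_idx is [[0, 0, 0, 0, 1, 1]],
    # and each example's first label is replaced by the separator id (-100) so the loss
    # never crosses a sequence boundary.
    print(batch["input_ids"].shape, batch["cu_seq_lens_q"], batch["position_ids"])

Because the flattened batch carries no attention_mask, the token-accounting loop in finetune.py (PATCH 1) falls back to position_ids.numel() or cu_seq_lens_q[-1] to count tokens, which is also why total_token_including_padding equals local_total_tokens in the padding-free case.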