33 changes: 30 additions & 3 deletions open_instruct/finetune.py
@@ -63,6 +63,7 @@
visualize_token,
)
from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
from open_instruct.utils import (
ArgumentParserPlus,
clean_last_n_checkpoints,
@@ -362,6 +363,11 @@ class FlatArguments:
oe_eval_max_length: int = 4096
"""the max generation length for evaluation for oe-eval"""

padding_free: bool = field(
default=False,
metadata={"help": "Whether to use padding-free collation via TensorDataCollatorWithFlattening"},
)

def __post_init__(self):
if self.reduce_loss not in ["mean", "sum"]:
raise ValueError("reduce_loss must be either 'mean' or 'sum'")
@@ -630,10 +636,18 @@ def main(args: FlatArguments, tc: TokenizerConfig):
model.gradient_checkpointing_enable()

# DataLoaders creation:
if args.padding_free:
collate_fn = TensorDataCollatorWithFlattening()
else:
collate_fn = DataCollatorForSeq2Seq(
tokenizer=tokenizer, model=model, padding="longest"
)

accelerator.print("Creating dataloader")
train_dataloader = DataLoader(
train_dataset,
shuffle=True,
collate_fn=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
collate_fn=collate_fn,
batch_size=args.per_device_train_batch_size,
)

@@ -765,8 +779,21 @@ def main(args: FlatArguments, tc: TokenizerConfig):
else:
active_dataloader = train_dataloader
for step, batch in enumerate(active_dataloader):
local_total_tokens += batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
if "attention_mask" in batch:
local_total_tokens += batch["attention_mask"].sum()
total_token_including_padding += batch["attention_mask"].numel()
elif "position_ids" in batch:
tokens_in_batch = batch["position_ids"].numel()
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
elif "cu_seq_lens_q" in batch:
tokens_in_batch = batch["cu_seq_lens_q"][-1]
local_total_tokens += tokens_in_batch
total_token_including_padding += tokens_in_batch
else:
raise ValueError(
f"Expected attention_mask or position_ids or cu_seq_lens_q in batch, found {batch=}"
)
with accelerator.accumulate(model):
if args.load_balancing_loss:
outputs = model(**batch, use_cache=False, output_router_logits=True)
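Aside on the token-accounting branch added above: the padding-free collator emits no `attention_mask`, so the total token count has to come from `position_ids` (one entry per flattened token) or from the cumulative sequence lengths, whose last entry equals the number of concatenated tokens. A minimal standalone sketch with made-up sequence lengths, not code from this PR:

```python
import torch

# Hypothetical per-example lengths within one flattened mini-batch.
seq_lens = [5, 3, 4]

# Cumulative boundaries, built the same way the collator builds them: [0, 5, 8, 12].
cu = [0]
for n in seq_lens:
    cu.append(cu[-1] + n)
cu_seq_lens_q = torch.tensor(cu, dtype=torch.int32)

# position_ids restart at 0 for every packed sequence, so numel() equals the total token count.
position_ids = torch.cat([torch.arange(n) for n in seq_lens])[None]

# Both branches in finetune.py therefore count the same number of tokens, with no padding.
assert int(cu_seq_lens_q[-1]) == position_ids.numel() == sum(seq_lens) == 12
```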
87 changes: 87 additions & 0 deletions open_instruct/padding_free_collator.py
@@ -0,0 +1,87 @@
import warnings
from dataclasses import dataclass

import torch
from transformers import DefaultDataCollator


@dataclass
class TensorDataCollatorWithFlattening(DefaultDataCollator):
"""
Data collator used for padding free approach. Does the following:

- concatate the entire mini batch into single long sequence [1, total_tokens]
- uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
- no padding will be added, returns `input_ids`, `labels` and `position_ids`
"""

def __init__(
self,
*args,
return_flash_attn_kwargs=True,
return_position_ids=True,
return_seq_idx=True,
separator_id=-100,
**kwargs,
):
super().__init__(*args, **kwargs)
self.return_flash_attn_kwargs = return_flash_attn_kwargs
self.return_position_ids = return_position_ids
self.return_seq_idx = return_seq_idx
self.separator_id = separator_id
        warnings.warn(
            "Using `TensorDataCollatorWithFlattening` will flatten the entire mini-batch into a single "
            "long sequence. Make sure your attention computation is able to handle it!"
        )

def __call__(self, features, return_tensors=None, separator_id=None):
if return_tensors is None:
return_tensors = self.return_tensors
if separator_id is None:
separator_id = self.separator_id
        assert self.return_flash_attn_kwargs, (
            "This collator should only be used with return_flash_attn_kwargs=True"
        )
if self.return_flash_attn_kwargs:
cu_seq_lens = [0]
max_length = 0
if self.return_position_ids:
pos_ids = []
if self.return_seq_idx:
seq_idx = []
is_labels_provided = "labels" in features[0]
ret = {"input_ids": [], "labels": []}
separator = torch.tensor(
[separator_id],
dtype=features[0]["input_ids"].dtype,
device=features[0]["input_ids"].device,
)
        # Concatenate each example onto the flattened batch, recording its length and boundaries.
        for s_idx, item in enumerate(features):
input_ids = item["input_ids"]
ret["input_ids"].append(input_ids)
if is_labels_provided:
ret["labels"].append(separator)
ret["labels"].append(item["labels"][1:])
else:
ret["labels"].append(separator)
ret["labels"].append(input_ids[1:])
if self.return_flash_attn_kwargs:
cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
max_length = max(max_length, len(input_ids))
if self.return_position_ids:
pos_ids.append(torch.arange(input_ids.numel(), device=input_ids.device))
if self.return_seq_idx:
seq_idx.append(torch.full_like(input_ids, s_idx, dtype=torch.int32))

        # Flash-attention varlen kwargs: queries and keys share the same cumulative sequence boundaries.
        if self.return_flash_attn_kwargs:
ret["cu_seq_lens_q"] = ret["cu_seq_lens_k"] = torch.tensor(
cu_seq_lens, dtype=torch.int32, device=features[0]["input_ids"].device
)
ret["max_length_q"] = ret["max_length_k"] = max_length
if self.return_position_ids:
ret["position_ids"] = torch.cat(pos_ids, dim=0)[None]
if self.return_seq_idx:
ret["seq_idx"] = torch.cat(seq_idx, dim=0)[None]
ret["input_ids"] = torch.cat(ret["input_ids"], dim=0)[None]
ret["labels"] = torch.cat(ret["labels"], dim=0)[None]
return ret
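For reviewers, a minimal standalone usage sketch of the collator above (the toy token ids are illustrative only, not from this PR):

```python
import torch

from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

# Two variable-length, already-tokenized examples.
features = [
    {"input_ids": torch.tensor([11, 12, 13, 14, 15])},
    {"input_ids": torch.tensor([21, 22, 23])},
]

batch = TensorDataCollatorWithFlattening()(features)

print(batch["input_ids"].shape)  # torch.Size([1, 8]) -- one flattened sequence
print(batch["position_ids"])     # tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
print(batch["seq_idx"])          # tensor([[0, 0, 0, 0, 0, 1, 1, 1]], dtype=torch.int32)
print(batch["cu_seq_lens_q"])    # tensor([0, 5, 8], dtype=torch.int32)
print(batch["labels"])           # tensor([[-100, 12, 13, 14, 15, -100, 22, 23]])
```

The separator replaces the first label of each example, mirroring Hugging Face's `DataCollatorWithFlattening` convention, so the model's internal next-token shift never crosses a sequence boundary.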
127 changes: 127 additions & 0 deletions tests/test_padding_free.py
@@ -0,0 +1,127 @@
import sys
from pathlib import Path

import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    BambaConfig,
    BambaForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PretrainedConfig,
)

# HACK for being able to load the collator without needing to install open-instruct
open_instruct_dir = Path(__file__).parent.parent.absolute()
sys.path.append(open_instruct_dir)
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

MODEL_CLASSES = {"bamba": BambaForCausalLM, "llama": LlamaForCausalLM}
MODEL_CFGS = {
"bamba": BambaConfig,
"llama": LlamaConfig,
}
MODEL_KWARGS = {
"bamba": dict(
attention_dropout=0.0,
attn_layer_indices=None,
attn_rotary_emb=8,
hidden_act="silu",
hidden_size=32,
initializer_range=0.02,
intermediate_size=64,
mamba_chunk_size=16,
mamba_d_conv=4,
mamba_d_state=16,
mamba_expand=2,
mamba_n_groups=1,
mamba_n_heads=16,
max_position_embeddings=512,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
pad_token_id=0,
vocab_size=99,
),
"llama": dict(
hidden_act="gelu",
hidden_size=32,
intermediate_size=64,
is_training=True,
max_position_embeddings=512,
mlp_bias=False,
num_attention_heads=2,
num_hidden_layers=2,
num_key_value_heads=2,
vocab_size=99,
),
}


class TestPaddingFree:
seqlen = 128
batch_size = 2
dtype = torch.bfloat16

    def get_fa2_model_and_cfg(self, model_name: str) -> tuple[nn.Module, PretrainedConfig]:
model_cls = MODEL_CLASSES[model_name]
model_cfg = MODEL_CFGS[model_name]
model_kwargs = MODEL_KWARGS[model_name]
cfg = model_cfg(
**{
**model_kwargs,
"torch_dtype": self.dtype,
"attn_implementation": "flash_attention_2",
}
)
model = model_cls(cfg).to("cuda", dtype=self.dtype)
return model, cfg

@pytest.mark.parametrize("model_name", ["bamba", "llama"])
def test_padding_free(self, model_name: str) -> None:
torch.manual_seed(42)
model, cfg = self.get_fa2_model_and_cfg(model_name)

inputs = torch.randint(cfg.vocab_size, size=(self.batch_size, self.seqlen), device="cuda")
# Non-padding-free batch:
batch = {
"input_ids": inputs,
"labels": inputs,
"attention_mask": torch.ones_like(inputs),
}

# Padding-free batch from the collator
dataset = {idx: {"input_ids": example} for idx, example in enumerate(inputs)}
collate_fn = TensorDataCollatorWithFlattening()
train_dataloader = DataLoader(
dataset,
shuffle=False,
collate_fn=collate_fn,
batch_size=self.batch_size,
)
pf_batch = next(iter(train_dataloader))
assert batch["input_ids"].shape[0] == 2
assert pf_batch["input_ids"].shape[0] == 1

        # Also create a batch with the padding-free-style concatenation, but without the padding-free
        # sequence markers, as a control. Passing it through the model should give incorrect results.

incorrect_pf_batch = {
"input_ids": pf_batch["input_ids"],
"labels": pf_batch["labels"],
"attention_mask": torch.ones_like(pf_batch["input_ids"]),
}

with torch.no_grad():
outputs = model(**batch)
pf_outputs = model(**pf_batch)
incorrect_pf_outputs = model(**incorrect_pf_batch)

logits = outputs.logits.reshape(1, -1, outputs.logits.shape[-1])
pf_logits = pf_outputs.logits
incorrect_pf_logits = incorrect_pf_outputs.logits
torch.testing.assert_close(pf_logits, logits)
torch.testing.assert_close(outputs.loss, pf_outputs.loss)
with pytest.raises(AssertionError, match="Mismatched elements:"):
torch.testing.assert_close(pf_logits, incorrect_pf_logits)
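One possible way to run just this test module (assuming a CUDA GPU with flash-attn installed, since the models are built with `attn_implementation="flash_attention_2"`; this runner is a suggestion, not part of the PR):

```python
# Equivalent to `pytest -v tests/test_padding_free.py` on the command line.
import pytest

raise SystemExit(pytest.main(["-v", "tests/test_padding_free.py"]))
```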