35 changes: 35 additions & 0 deletions open_instruct/dataset_transformation.py
@@ -1427,6 +1427,40 @@ def preference_tulu_filter_v1(row: Dict[str, Any], tokenizer: PreTrainedTokenize
    return any(x != -100 for x in row[CHOSEN_LABELS_KEY]) and any(x != -100 for x in row[REJECTED_LABELS_KEY])


def preference_span_search_mask_out(
    row: Dict[str, Any],
    tokenizer: PreTrainedTokenizer,
    max_seq_length: int,
    chosen_key: str = DEFAULT_CHOSEN_KEY,
    rejected_key: str = DEFAULT_REJECTED_KEY,
):
    """
    Here we assume each example has a rejected and chosen field, both of which are a list of messages.
    Each message is a dict with 'role' and 'content' fields.
    We assume only the last message is different, and the prompt is contained in the list of messages.
    """
    chosen_messages = row[chosen_key]
    rejected_messages = row[rejected_key]
    if len(chosen_messages) == 0:
        raise ValueError("chosen messages field is empty.")
    if len(rejected_messages) == 0:
        raise ValueError("rejected messages field is empty.")

    chosen_encoded = sft_span_seach_mask_out({DEFAULT_SFT_MESSAGES_KEY: chosen_messages}, tokenizer, max_seq_length)
    rejected_encoded = sft_span_seach_mask_out(
        {DEFAULT_SFT_MESSAGES_KEY: rejected_messages}, tokenizer, max_seq_length
    )

    return {
        CHOSEN_INPUT_IDS_KEY: chosen_encoded["input_ids"],
        CHOSEN_LABELS_KEY: chosen_encoded["labels"],
        CHOSEN_ATTENTION_MASK_KEY: chosen_encoded["attention_mask"],
        REJECTED_INPUT_IDS_KEY: rejected_encoded["input_ids"],
        REJECTED_LABELS_KEY: rejected_encoded["labels"],
        REJECTED_ATTENTION_MASK_KEY: rejected_encoded["attention_mask"],
    }


def rlvr_tokenize_v1(
    row: Dict[str, Any],
    tokenizer: PreTrainedTokenizer,
@@ -1514,6 +1548,7 @@ def rlvr_filter_v1(
    "preference_tokenize_v1": (preference_tokenize_v1, "map"),
    "preference_filter_v1": (preference_filter_v1, "filter"),
    "preference_tulu_tokenize_and_truncate_v1": (preference_tulu_tokenize_and_truncate_v1_2, "map"),
    "preference_span_search_mask_out": (preference_span_search_mask_out, "map"),
    "preference_tulu_filter_v1": (preference_tulu_filter_v1, "filter"),
    "rlvr_tokenize_v1": (rlvr_tokenize_v2, "map"),
    "rlvr_filter_v1": (rlvr_filter_v1, "filter"),
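For orientation, a minimal sketch of how the new transform could be applied to a single preference example outside the transform registry shown above. This is illustrative only: it assumes the default keys resolve to "chosen"/"rejected", and the tokenizer checkpoint is just an example of a chat-templated tokenizer.

from transformers import AutoTokenizer

from open_instruct.dataset_transformation import preference_span_search_mask_out

# Illustrative checkpoint; any tokenizer with a chat template should work.
tokenizer = AutoTokenizer.from_pretrained("allenai/Llama-3.1-Tulu-3-8B")

row = {
    "chosen": [
        {"role": "user", "content": "Why did the chicken cross the road?"},
        {"role": "assistant", "content": "To get to the other side."},
    ],
    "rejected": [
        {"role": "user", "content": "Why did the chicken cross the road?"},
        {"role": "assistant", "content": "To make friends with a cow."},
    ],
}

encoded = preference_span_search_mask_out(row, tokenizer, max_seq_length=2048)
# The returned dict mirrors the keys built in the function above:
# chosen/rejected input_ids, labels, and attention_mask.

In dataset configs, the transform is then referenced by its registry name, "preference_span_search_mask_out", as added to the mapping above.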
7 changes: 5 additions & 2 deletions open_instruct/padding_free_collator.py
@@ -4,6 +4,8 @@
import torch
from transformers import DefaultDataCollator

from open_instruct.model_utils import log_softmax_and_gather


@dataclass
class TensorDataCollatorWithFlattening(DefaultDataCollator):
@@ -129,7 +131,8 @@ def concatenated_inputs(

if "seq_idx" in chosen_features:
ret[f"{tag}seq_idx"] = torch.cat(
[chosen_features["seq_idx"], rejected_features["seq_idx"] + chosen_features["seq_idx"][0, -1]], dim=-1
[chosen_features["seq_idx"], rejected_features["seq_idx"] + len(chosen_features["cu_seq_lens_k"]) - 1],
dim=-1,
)

return ret, len(chosen_features["cu_seq_lens_k"]) - 1
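A quick illustration of why the rejected seq_idx offset changed; the shapes and values below are made up for the example. With two packed chosen sequences, the last chosen sequence index is 1, but there are two chosen sequences, so offsetting by len(cu_seq_lens_k) - 1 rather than by the last index keeps rejected indices from colliding with the final chosen one.

import torch

# Two chosen sequences of lengths 3 and 2 packed into one flattened row.
chosen_seq_idx = torch.tensor([[0, 0, 0, 1, 1]])
chosen_cu_seq_lens_k = torch.tensor([0, 3, 5])  # boundaries of the two chosen sequences
rejected_seq_idx = torch.tensor([[0, 0, 1, 1, 1]])

num_chosen = len(chosen_cu_seq_lens_k) - 1  # 2 sequences
old_offset = chosen_seq_idx[0, -1]          # 1 -> first rejected sequence would reuse index 1
new_offset = num_chosen                     # 2 -> rejected indices become [2, 2, 3, 3, 3]

shifted = rejected_seq_idx + new_offset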
@@ -156,7 +159,7 @@ def get_batch_logps(
    cu_seq_lens[0] = 0

    splits = cu_seq_lens.diff().tolist()
-    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
+    per_token_logps = log_softmax_and_gather(logits, labels)

    return torch.concat(
        [
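The switch to log_softmax_and_gather reuses the shared helper imported from open_instruct/model_utils.py instead of building the full log-softmax tensor and then gathering. A rough, purely illustrative equivalent of what that call computes (not the actual helper):

import torch


def per_token_logps_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-token log-probabilities of `labels`, equivalent to
    torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2).

    Assumes ignore-index positions in `labels` were already replaced,
    as the original gather expression also required.
    """
    # log_softmax(x)[y] == x[y] - logsumexp(x), so gather first and normalize after.
    gathered = torch.gather(logits, dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return gathered - torch.logsumexp(logits, dim=-1)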
89 changes: 85 additions & 4 deletions tests/test_padding_free.py
@@ -17,8 +17,19 @@
)

from open_instruct.dataset_processor import CHAT_TEMPLATES
-from open_instruct.dataset_transformation import sft_tulu_tokenize_and_truncate_v1
-from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
+from open_instruct.dataset_transformation import (
+    sft_tulu_tokenize_and_truncate_v1,
+    preference_tulu_tokenize_and_truncate_v1_2,
+)
+from open_instruct.padding_free_collator import (
+    TensorDataCollatorWithFlattening,
+    TensorDataCollatorWithFlatteningDPO,
+)
+from open_instruct.dpo_utils import (
+    DataCollatorForSeq2SeqDPO,
+    concatenated_forward,
+)


try:
    import mamba_ssm  # noqa
@@ -102,8 +113,11 @@ class TestPaddingFree:
    seqlen = 128
    batch_size = 2
    dtype = torch.bfloat16
    model = None

    def get_fa2_model_and_cfg(self, model_name: str, vocab_size: int) -> nn.Module:
        if self.model is not None:
            return self.model, self.model.config
        model_cls = MODEL_CLASSES[model_name]
        model_cfg = MODEL_CFGS[model_name]
        model_kwargs = MODEL_KWARGS[model_name]
@@ -116,6 +130,7 @@ def get_fa2_model_and_cfg(self, model_name: str, vocab_size: int) -> nn.Module:
            }
        )
        model = model_cls(cfg).to("cuda", dtype=self.dtype)
        self.model = model
        return model, cfg

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Padding free tests require CUDA")
@@ -136,8 +151,6 @@ def test_padding_free(self, model_name: str, loss_type: str) -> None:
        model.initialize_weights()
        pf_model = deepcopy(model)

-        inputs = torch.randint(cfg.vocab_size, size=(self.batch_size, self.seqlen), device="cpu")

        data = {
            0: {
                "messages": [
@@ -213,3 +226,71 @@ def test_padding_free(self, model_name: str, loss_type: str) -> None:
        pf_grads = {n: p.grad for n, p in pf_model.named_parameters()}
        for k, g in grads.items():
            torch.testing.assert_close(g, pf_grads[k])
            non_nan_grads.add(k)
        print(f"{non_nan_grads=}")
        print(f"{nan_grads=}")


    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Padding free tests require CUDA")
    @pytest.mark.skipif(not flash_attn_available, reason="Padding free requires flash_attn")
    @pytest.mark.parametrize("model_name", ["bamba", "llama"])
    def test_padding_free_dpo(self, model_name: str) -> None:
        if model_name == "bamba" and not mamba_and_causal_conv_available:
            pytest.skip("bamba padding-free tests require mamba_ssm and causal_conv1d")
        torch.manual_seed(42)

        tokenizer = AutoTokenizer.from_pretrained("ibm-ai-platform/Bamba-9B-v2")
        tokenizer.add_special_tokens({"pad_token": "<pad>"})
        tokenizer.chat_template = CHAT_TEMPLATES["tulu"]
        vocab_size = len(tokenizer)

        model, cfg = self.get_fa2_model_and_cfg(model_name, vocab_size)
        model.initialize_weights()
        pf_model = deepcopy(model)

        data = {
            0: {
                "chosen": [
                    {"role": "user", "content": "Why did the chicken cross the road?"},
                    {"role": "assistant", "content": "To get to the other side"},
                ],
                "rejected": [
                    {"role": "user", "content": "Why did the chicken cross the road?"},
                    {"role": "assistant", "content": "To make friends with a cow. Trying to make a long response to simulate hallucination."},
                ],
            },
            1: {
                "chosen": [
                    {"role": "user", "content": "What is one plus two?"},
                    {"role": "assistant", "content": "The answer is 3"},
                ],
                "rejected": [
                    {"role": "user", "content": "What is one plus two?"},
                    {"role": "assistant", "content": "I am a beginner in math. I have not yet learnt addition. The answer is 12"},
                ],
            },
        }

        tok_data = {k: preference_tulu_tokenize_and_truncate_v1_2(v, tokenizer, max_seq_length=2**30) for k, v in data.items()}

        collate_fn = DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=model, padding="longest")
        dataloader = DataLoader(tok_data, shuffle=False, collate_fn=collate_fn, batch_size=self.batch_size)

        pf_collate_fn = TensorDataCollatorWithFlatteningDPO()
        pf_dataloader = DataLoader(tok_data, shuffle=False, collate_fn=pf_collate_fn, batch_size=self.batch_size)

        batch = next(iter(dataloader))
        pf_batch = next(iter(pf_dataloader))
        for b in (batch, pf_batch):
            for k in b:
                if torch.is_tensor(b[k]):
                    b[k] = b[k].cuda()

        assert batch["chosen_input_ids"].shape[0] == 2
        assert pf_batch["chosen_input_ids"].shape[0] == 1
        chosen_logps, rejected_logps, _ = concatenated_forward(model, batch, average_log_prob=False)
        pf_chosen_logps, pf_rejected_logps, _ = concatenated_forward(pf_model, pf_batch, average_log_prob=False, packing=True)

        torch.testing.assert_close(chosen_logps, pf_chosen_logps, atol=1e-3, rtol=1e-3)
        torch.testing.assert_close(rejected_logps, pf_rejected_logps, atol=1e-3, rtol=1e-3)
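For intuition on the two shape assertions above: the padded DPO collator keeps one row per example (hence a leading dimension of 2), while the flattening collator packs every sequence in the batch into a single row and tracks boundaries separately (hence a leading dimension of 1). A toy illustration of the two layouts, not the collators' actual code:

import torch

# Two tokenized sequences of different lengths.
seqs = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Padded collation: one row per sequence, padded to the longest one -> shape (2, 3).
padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)

# Padding-free collation: a single packed row plus boundary metadata -> shape (1, 5).
flat = torch.cat(seqs).unsqueeze(0)
cu_seq_lens = torch.tensor([0, 3, 5])      # cumulative sequence boundaries
seq_idx = torch.tensor([[0, 0, 0, 1, 1]])  # which sequence each token came from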
