Skip to content

Commit 47ab034

Browse files
kashif, qgallouedec, lewtun
authored
[DPO] tokenize and process DPO data via batches (#1914)
* tokenize and process DPO data via batches * use helpers * updated _process_tokens * fixed * incorporate build_tokenized_answer in the _tokenizer * Update trl/trainer/dpo_trainer.py Co-authored-by: Quentin Gallouédec <[email protected]> * Update trl/trainer/dpo_trainer.py Co-authored-by: Quentin Gallouédec <[email protected]> * fix tokenizer for is_vision_model * Update trl/trainer/dpo_trainer.py Co-authored-by: Quentin Gallouédec <[email protected]> * give the _tokenize the tokenizer as well as optional processor * fix tests * add bos and eos tokens * add prompt_pixel_attention_mask * Update trl/trainer/dpo_trainer.py Co-authored-by: Quentin Gallouédec <[email protected]> * truncate by max_length * formatting * fix for enc-dec * For encoder-decoder models, we need to use the prepared decoder_input_ids * add tests for _build_tokenized_answer and _tokenize_feature * check for EOS and BOS tokens * formatting * do not include pixel mask if they are not provided * undo refactor * undo add_bos_token_if_needed change * refactor tokenizer into smaller helpers * add back comments * fix type hints * format * fix t5 tests * args are never optional * move cat to appropriate helper * fix _truncate_tokens * add tests for _truncate_tokens * remove dead code --------- Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: lewtun <[email protected]>
1 parent e755eee commit 47ab034

File tree

3 files changed

+467
-241
lines changed

3 files changed

+467
-241
lines changed

Diff for: tests/test_dpo_trainer.py

+160-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,170 @@
3030
)
3131

3232
from trl import DPOConfig, DPOTrainer, FDivergenceType
33+
from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens
3334

3435
from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
3536

3637

38+
class TestBuildTokenizedAnswer(unittest.TestCase):
    """Unit tests for the module-level ``_build_tokenized_answer`` helper."""

    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def test_basic_functionality(self):
        prompt = "Hello, how are you?"
        answer = "I'm doing well, thank you!"

        result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

        # All four token/mask fields must be present and internally consistent.
        for key in ("prompt_input_ids", "prompt_attention_mask", "input_ids", "attention_mask"):
            self.assertIn(key, result)
        self.assertEqual(len(result["prompt_input_ids"]), len(result["prompt_attention_mask"]))
        self.assertEqual(len(result["input_ids"]), len(result["attention_mask"]))

        # Decoding the ids should recover the original prompt and answer text.
        self.assertIn(prompt, self.tokenizer.decode(result["prompt_input_ids"]))
        self.assertIn(answer, self.tokenizer.decode(result["input_ids"]))

    def test_with_processor(self):
        def mock_processor(text, images=None, add_special_tokens=True):
            return {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}

        prompt = "Describe this image:"
        answer = "A beautiful sunset over the ocean."

        result = _build_tokenized_answer(prompt, answer, processor=mock_processor)

        for key in ("prompt_input_ids", "prompt_attention_mask", "input_ids", "attention_mask"):
            self.assertIn(key, result)

        # Batched tensors from the processor are expected to come back as flat lists.
        self.assertEqual(result["prompt_input_ids"], [1, 2, 3])
        self.assertEqual(result["prompt_attention_mask"], [1, 1, 1])

    def test_token_merging(self):
        prompt = "The quick brown"
        answer = " fox jumps over the lazy dog."

        result = _build_tokenized_answer(prompt, answer, tokenizer=self.tokenizer)

        # Prompt ids + answer ids must equal tokenizing the concatenated text in one pass.
        full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
        self.assertEqual(result["prompt_input_ids"] + result["input_ids"], full_tokenized["input_ids"])

    def test_vision_model(self):
        def mock_vision_processor(text, images=None, add_special_tokens=True):
            return {
                "input_ids": torch.tensor([[1, 2, 3]]),
                "attention_mask": torch.tensor([[1, 1, 1]]),
                "pixel_values": torch.rand(1, 3, 224, 224),
                "pixel_attention_mask": torch.ones(1, 224, 224),
            }

        prompt = "Describe this image:"
        answer = "A cat sitting on a windowsill."

        result = _build_tokenized_answer(prompt, answer, processor=mock_vision_processor)

        # Vision inputs should be passed through under prompt_-prefixed keys, still as tensors.
        self.assertIn("prompt_pixel_values", result)
        self.assertIn("prompt_pixel_attention_mask", result)
        self.assertTrue(torch.is_tensor(result["prompt_pixel_values"]))
        self.assertTrue(torch.is_tensor(result["prompt_pixel_attention_mask"]))
111+
class TestTruncateTokens(unittest.TestCase):
    """Unit tests for the module-level ``_truncate_tokens`` helper."""

    def setUp(self):
        # Bug fix: the original wrapped this in `with tempfile.TemporaryDirectory()`,
        # which deleted the directory as soon as setUp returned, leaving
        # `args.output_dir` pointing at a nonexistent path for the whole test.
        # Keep it alive for the test's lifetime and clean it up afterwards.
        self.tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.tmp_dir.cleanup)
        self.args = DPOConfig(
            max_length=20, max_prompt_length=10, truncation_mode="keep_start", output_dir=self.tmp_dir.name
        )

    def test_truncate_tokens(self):
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(10)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(12)),
                "attention_mask": [1] * 12,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        # Mutates the dicts in place according to max_prompt_length / max_length.
        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # Prompt is truncated to max_prompt_length (10) in every structure.
        self.assertEqual(len(chosen_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_attention_mask"]), 10)

        # Responses are truncated so prompt + response fits in max_length (20).
        self.assertEqual(len(chosen_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["attention_mask"]), 10)

    def test_truncation_mode_keep_end(self):
        self.args.truncation_mode = "keep_end"
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 25)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 28)),
                "attention_mask": [1] * 13,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # "keep_end" keeps the LAST max_prompt_length prompt tokens (5..14).
        self.assertEqual(prompt_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(prompt_tokens[0]["prompt_attention_mask"], [1] * 10)

        # Chosen: same prompt tail, response truncated to fit max_length.
        self.assertEqual(chosen_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(chosen_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(chosen_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(chosen_tokens[0]["attention_mask"], [1] * 10)

        # Rejected: 13-token response clipped to the first 10 tokens.
        self.assertEqual(rejected_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(rejected_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(rejected_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(rejected_tokens[0]["attention_mask"], [1] * 10)

    def test_invalid_truncation_mode(self):
        # Any mode other than "keep_start"/"keep_end" must be rejected.
        self.args.truncation_mode = "invalid_mode"
        with self.assertRaises(ValueError):
            _truncate_tokens([], [], [], self.args)
37197
class DPOTrainerTester(unittest.TestCase):
38198
def setUp(self):
39199
self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
@@ -138,9 +298,6 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
138298
ref_model = self.t5_ref_model
139299
tokenizer = self.t5_tokenizer
140300

141-
if name == "t5":
142-
self.skipTest("For some reason t5 does not compute gradients properly on tiny models")
143-
144301
trainer = DPOTrainer(
145302
model=model,
146303
ref_model=ref_model,

0 commit comments

Comments (0)